//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted as `subtensor` ops and are
/// canonicalized to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

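/// Two tensor types are cast-compatible when their element types match and
/// their shapes disagree only in ways a runtime check could reconcile: every
/// pair of static extents must match, and any remaining mismatch must involve
/// a dynamic or unranked side. For example (illustrative), tensor<?xf32> and
/// tensor<4xf32> are cast-compatible, while tensor<4xf32> and tensor<8xf32>
/// are not.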
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two given
/// TensorTypes. The element types need to match. Returns a null TensorType
/// when the shape information conflicts, i.e. when no join exists.
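///
/// For example (illustrative):
///   join(tensor<?x?xf32>, tensor<4x?xf32>) = tensor<4x?xf32>
///   join(tensor<4x?xf32>, tensor<?x8xf32>) = tensor<4x8xf32>
///   join(tensor<4x?xf32>, tensor<8x?xf32>) = null (conflicting extents)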
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
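///
/// For example (an illustrative chain):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// folds to a single cast from tensor<4x4xf32> to tensor<4x?xf32>. By
/// contrast, a chain tensor<?xf32> -> tensor<4xf32> -> tensor<?xf32> is left
/// alone: dropping the intermediate cast would drop its runtime size check.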
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // If the above join exists, so does newJoin; it might just contain less
    // information. If so, we cannot drop the intermediate cast, as doing so
    // would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify that the number of indices matches the rank of the tensor type,
  // if the type is ranked.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

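/// Folds extraction out of a constant tensor. As an illustrative sketch,
/// given constant operands, a pattern like
///
///   %cst = constant dense<[1, 2]> : tensor<2xi32>
///   %0 = tensor.extract %cst[%c1] : tensor<2xi32>
///
/// folds %0 to the attribute `2 : i32` when %c1 is the constant index 1.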
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indexAttr : llvm::drop_begin(operands, 1)) {
    if (!indexAttr || !indexAttr.isa<IntegerAttr>())
      return {};
    indices.push_back(indexAttr.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

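/// Builds a 1-D tensor.from_elements of the given element type. A minimal
/// usage sketch (assuming hypothetical index-typed values `a` and `b` are in
/// scope):
///
///   Value tensor = builder.create<FromElementsOp>(
///       loc, builder.getIndexType(), ValueRange{a, b});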
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements %element : tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

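/// As an illustrative example of IR accepted by the verifier below: one index
/// operand per dynamic extent, one index block argument per result dimension,
/// and a yield of the element type.
///
/// ```mlir
///   %0 = tensor.generate %size {
///   ^bb0(%i: index, %j: index):
///     %cst = constant 0.0 : f32
///     tensor.yield %cst : f32
///   } : tensor<4x?xf32>
/// ```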
static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

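/// Builds a GenerateOp whose body is populated via `bodyBuilder`; the
/// callback is expected to create the terminating tensor.yield itself. A
/// minimal usage sketch (assuming a hypothetical index-typed `size` value):
///
///   auto type = RankedTensorType::get({RankedTensorType::kDynamicSize},
///                                     b.getIndexType());
///   auto op = b.create<GenerateOp>(
///       loc, type, ValueRange{size},
///       [](OpBuilder &b, Location loc, ValueRange indices) {
///         b.create<tensor::YieldOp>(loc, indices[0]);
///       });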
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations whose dynamic extents are defined
/// by constant values, folding each such extent into the result type. A
/// tensor.cast back to the original result type is inserted so that the
/// resulting IR remains well-typed.
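///
/// For example (illustrative), with %c16 a constant index:
///
/// ```mlir
///   %0 = tensor.generate %c16 { ... } : tensor<?xindex>
/// ```
///
/// becomes
///
/// ```mlir
///   %1 = tensor.generate { ... } : tensor<16xindex>
///   %0 = tensor.cast %1 : tensor<16xindex> to tensor<?xindex>
/// ```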
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        generateOp.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = generateOp.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == generateOp.dynamicExtents().size())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(generateOp.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp, resultType, newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto generateOp = extract.tensor().getDefiningOp<GenerateOp>();
    if (!generateOp || !wouldOpBeTriviallyDead(generateOp))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = generateOp.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
                 StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"