//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// to implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted when `subtensor` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type carries at least as much static shape information as the
///    result type, i.e. no dimension is dynamic in the source but static in
///    the result.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
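///
/// A consuming op can call this from its fold hook; a minimal sketch, assuming
/// a hypothetical `MyConsumerOp` that has tensor operands and results:
///
///   LogicalResult MyConsumerOp::fold(ArrayRef<Attribute> operands,
///                                    SmallVectorImpl<OpFoldResult> &results) {
///     // Rewrites foldable cast-producing operands in place; leaving
///     // `results` empty signals an in-place update.
///     return foldTensorCast(*this);
///   }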
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

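/// Two types are cast-compatible for tensor.cast when both are tensor types
/// with the same element type and compatible shapes: every dimension must be
/// equal or dynamic on at least one side, and an unranked tensor is compatible
/// with any tensor of the same element type. For example, tensor<?x4xf32> is
/// cast-compatible with tensor<8x?xf32> and tensor<*xf32>, but not with
/// tensor<8xf32> (rank mismatch) or tensor<?x4xi32> (element type mismatch).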
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
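/// Returns the other type when only one of the two is ranked, and a null type
/// when the static sizes conflict. For example, joining tensor<?x4xf32> with
/// tensor<8x?xf32> yields tensor<8x4xf32>, while joining tensor<4xf32> with
/// tensor<8xf32> yields a null type.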
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
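///
/// For example:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x?xf32>
/// ```
///
/// folds into a single cast from %0 directly to tensor<8x?xf32>, because every
/// piece of shape information checked by the pair of casts is still checked by
/// the remaining cast. If the intermediate type were more static than both the
/// source and the result, dropping it would remove a runtime check, so the
/// rewrite would not apply.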
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract_element");

  return success();
}

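/// Folds tensor.extract when the tensor operand is a known constant: a splat
/// constant folds to its splat value for any indices, and a general constant
/// with constant, in-bounds indices folds to the element at those indices
/// (e.g. extracting index [1] from the constant dense<[0, 5, 3]> yields 5).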
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

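/// Folds tensor.from_elements into a dense constant when every element operand
/// is itself a constant, e.g. from_elements of the constants 1 and 2 folds to
/// dense<[1, 2]> for a tensor<2xi32> result.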
OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
  return {};
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out of bounds accesses. This can happen in invalid code that will
    // never execute.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

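/// A well-formed tensor.generate has one index operand per dynamic dimension
/// of its result type, one index block argument per result dimension, and a
/// body terminated by a yield of the element type. Illustrative form (the
/// names below are placeholders):
///
///   %t = tensor.generate %m, %n {
///   ^bb0(%i : index, %j : index):
///     ...
///     yield %elem : f32
///   } : tensor<?x?xf32>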
static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are specified
  // by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations whose dynamic extent operands are
/// constants by folding those extents into the result type. A tensor.cast back
/// to the original result type is inserted so that the resulting IR remains
/// well-typed.
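///
/// For example, with %c5 a constant 5 (the names below are placeholders):
///
///   %0 = tensor.generate %c5, %n { ... } : tensor<?x?xf32>
///
/// becomes
///
///   %1 = tensor.generate %n { ... } : tensor<5x?xf32>
///   %0 = tensor.cast %1 : tensor<5x?xf32> to tensor<?x?xf32>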
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

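/// A tensor.reshape is well-formed when the source and result element types
/// match, a statically-shaped source and result agree on the total number of
/// elements, and, for a ranked result, the 1-D shape operand has a static
/// length equal to the result rank.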
static LogicalResult verify(ReshapeOp op) {
  TensorType operandType = op.source().getType().cast<TensorType>();
  TensorType resultType = op.result().getType().cast<TensorType>();

  if (operandType.getElementType() != resultType.getElementType())
    return op.emitOpError("element types of source and destination tensor "
                          "types should be the same");

  int64_t shapeSize =
      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
        return op.emitOpError("source and destination tensor should have the "
                              "same number of elements");
    }
    if (shapeSize == TensorType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"