//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as part of `subtensor` op
/// canonicalization to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has at least as much static information as the result
///    type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
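/// Returns success() if at least one operand was replaced by the cast's
/// source value, failure() otherwise.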
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

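/// Returns true iff `inputs` and `outputs` each hold a single tensor type
/// with the same element type and compatible shapes: corresponding static
/// dimensions must agree, while dynamic dimensions and unranked tensors are
/// compatible with anything.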
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
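/// Returns a null TensorType when no join exists, i.e. when the ranks differ
/// or some static dimension disagrees. For illustration, joining
/// tensor<?x8xf32> with tensor<4x?xf32> yields tensor<4x8xf32>, whereas
/// tensor<4xf32> and tensor<8xf32> have no join.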
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
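///
/// An illustrative example (the shapes here are hypothetical):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
///
/// is rewritten into the single cast
///
/// ```mlir
///   %2 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x8xf32>
/// ```
///
/// because dropping the intermediate fully-dynamic cast loses no runtime
/// constraint.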
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // If the above join exists, then newJoin exists as well; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

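/// Folds tensor.extract from a constant tensor. For a splat constant (all
/// elements equal) the splat value is returned regardless of the indices;
/// otherwise the element at the given indices is returned, provided the
/// indices are constant and in bounds.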
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute index : llvm::drop_begin(operands, 1)) {
    if (!index || !index.isa<IntegerAttr>())
      return {};
    indices.push_back(index.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out of bounds accesses. This can happen in invalid code that will
    // never execute.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are specified
  // by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with constant extent operands
/// into an equivalent operation whose result type expresses those extents
/// statically. A tensor.cast back to the original result type is inserted to
/// make sure that the resulting IR is still well-typed.
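///
/// For illustration (names, shapes, and the elided body are hypothetical),
/// if %c5 is a constant index 5, then
///
/// ```mlir
///   %0 = tensor.generate %c5, %size { ... } : tensor<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %1 = tensor.generate %size { ... } : tensor<5x?xf32>
///   %0 = tensor.cast %1 : tensor<5x?xf32> to tensor<?x?xf32>
/// ```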
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"