//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; using namespace mlir::tensor; //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// /// Determines whether tensor::CastOp casts to a more dynamic version of the /// source tensor. This is useful to fold a tensor.cast into a consuming op and /// implement canonicalization patterns for ops in different dialects that may /// consume the results of tensor.cast operations. Such foldable tensor.cast /// operations are typically inserted as `subtensor` ops and are canonicalized, /// to preserve the type compatibility of their uses. /// /// Returns true when all conditions are met: /// 1. source and result are ranked tensors with same element type and rank. /// 2. the tensor type has more static information than the result /// /// Example: /// ```mlir /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor /// %2 = consumer %1 ... : tensor ... /// ``` /// /// folds into: /// /// ```mlir /// %2 = consumer %0 ... : tensor<8x16xf32> ... /// ``` bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { if (!castOp) return false; RankedTensorType sourceType = castOp.source().getType().dyn_cast(); RankedTensorType resultType = castOp.getType().dyn_cast(); // Requires RankedTensorType. if (!sourceType || !resultType) return false; // Requires same elemental type. if (sourceType.getElementType() != resultType.getElementType()) return false; // Requires same rank. if (sourceType.getRank() != resultType.getRank()) return false; // If cast is towards more static sizes along any dimension, don't fold. for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) { if (ShapedType::isDynamic(std::get<0>(t)) && !ShapedType::isDynamic(std::get<1>(t))) return false; } return true; } bool CastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); if (!aT || !bT) return false; if (aT.getElementType() != bT.getElementType()) return false; return succeeded(verifyCompatibleShape(aT, bT)); } OpFoldResult CastOp::fold(ArrayRef operands) { return impl::foldCastOp(*this); } /// Compute a TensorType that has the joined shape knowledge of the two /// given TensorTypes. The element types need to match. static TensorType joinShapes(TensorType one, TensorType two) { assert(one.getElementType() == two.getElementType()); if (!one.hasRank()) return two; if (!two.hasRank()) return one; int64_t rank = one.getRank(); if (rank != two.getRank()) return {}; SmallVector join; join.reserve(rank); for (int64_t i = 0; i < rank; ++i) { if (one.isDynamicDim(i)) { join.push_back(two.getDimSize(i)); continue; } if (two.isDynamicDim(i)) { join.push_back(one.getDimSize(i)); continue; } if (one.getDimSize(i) != two.getDimSize(i)) return {}; join.push_back(one.getDimSize(i)); } return RankedTensorType::get(join, one.getElementType()); } namespace { /// Replaces chains of two tensor.cast operations by a single tensor.cast /// operation if doing so does not remove runtime constraints. struct ChainedTensorCast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CastOp tensorCast, PatternRewriter &rewriter) const final { auto tensorCastOperand = tensorCast.getOperand().getDefiningOp(); if (!tensorCastOperand) return failure(); auto sourceType = tensorCastOperand.getOperand().getType().cast(); auto intermediateType = tensorCastOperand.getType().cast(); auto resultType = tensorCast.getType().cast(); // We can remove the intermediate cast if joining all three produces the // same result as just joining the source and result shapes. auto firstJoin = joinShapes(joinShapes(sourceType, intermediateType), resultType); // The join might not exist if the cast sequence would fail at runtime. if (!firstJoin) return failure(); // The newJoin always exists if the above join exists, it might just contain // less information. If so, we cannot drop the intermediate cast, as doing // so would remove runtime checks. auto newJoin = joinShapes(sourceType, resultType); if (firstJoin != newJoin) return failure(); rewriter.replaceOpWithNewOp(tensorCast, resultType, tensorCastOperand.getOperand()); return success(); } }; } // namespace void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// static LogicalResult verify(ExtractOp op) { // Verify the # indices match if we have a ranked type. if (auto tensorType = op.tensor().getType().dyn_cast()) if (tensorType.getRank() != static_cast(op.indices().size())) return op.emitOpError("incorrect number of indices for extract_element"); return success(); } OpFoldResult ExtractOp::fold(ArrayRef operands) { // The tensor operand must be a known constant. Attribute tensor = operands.front(); if (!tensor) return {}; // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatTensor = tensor.dyn_cast()) return splatTensor.getSplatValue(); // Otherwise, collect the constant indices into the tensor. SmallVector indices; for (Attribute indice : llvm::drop_begin(operands, 1)) { if (!indice || !indice.isa()) return {}; indices.push_back(indice.cast().getInt()); } // If this is an elements attribute, query the value at the given indices. auto elementsAttr = tensor.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValue(indices); return {}; } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"