1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/IR/Tensor.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/PatternMatch.h" 12 #include "mlir/IR/TypeUtilities.h" 13 #include "llvm/ADT/STLExtras.h" 14 15 using namespace mlir; 16 using namespace mlir::tensor; 17 18 //===----------------------------------------------------------------------===// 19 // CastOp 20 //===----------------------------------------------------------------------===// 21 22 /// Determines whether tensor::CastOp casts to a more dynamic version of the 23 /// source tensor. This is useful to fold a tensor.cast into a consuming op and 24 /// implement canonicalization patterns for ops in different dialects that may 25 /// consume the results of tensor.cast operations. Such foldable tensor.cast 26 /// operations are typically inserted as `subtensor` ops and are canonicalized, 27 /// to preserve the type compatibility of their uses. 28 /// 29 /// Returns true when all conditions are met: 30 /// 1. source and result are ranked tensors with same element type and rank. 31 /// 2. the source tensor type has at least as much static information as the result type. 32 /// 33 /// Example: 34 /// ```mlir 35 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 36 /// %2 = consumer %1 ... : tensor<?x?xf32> ... 37 /// ``` 38 /// 39 /// folds into: 40 /// 41 /// ```mlir 42 /// %2 = consumer %0 ... : tensor<8x16xf32> ... 
43 /// ``` 44 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { 45 if (!castOp) 46 return false; 47 48 RankedTensorType sourceType = 49 castOp.source().getType().dyn_cast<RankedTensorType>(); 50 RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>(); 51 52 // Requires RankedTensorType. 53 if (!sourceType || !resultType) 54 return false; 55 56 // Requires same elemental type. 57 if (sourceType.getElementType() != resultType.getElementType()) 58 return false; 59 60 // Requires same rank. 61 if (sourceType.getRank() != resultType.getRank()) 62 return false; 63 64 // If cast is towards more static sizes along any dimension, don't fold. 65 for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) { 66 if (ShapedType::isDynamic(std::get<0>(t)) && 67 !ShapedType::isDynamic(std::get<1>(t))) 68 return false; 69 } 70 71 return true; 72 } 73 74 bool CastOp::areCastCompatible(Type a, Type b) { 75 auto aT = a.dyn_cast<TensorType>(); 76 auto bT = b.dyn_cast<TensorType>(); 77 if (!aT || !bT) 78 return false; 79 80 if (aT.getElementType() != bT.getElementType()) 81 return false; 82 83 return succeeded(verifyCompatibleShape(aT, bT)); 84 } 85 86 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 87 return impl::foldCastOp(*this); 88 } 89 90 /// Compute a TensorType that has the joined shape knowledge of the two 91 /// given TensorTypes. The element types need to match. 
92 static TensorType joinShapes(TensorType one, TensorType two) { 93 assert(one.getElementType() == two.getElementType()); 94 95 if (!one.hasRank()) 96 return two; 97 if (!two.hasRank()) 98 return one; 99 100 int64_t rank = one.getRank(); 101 if (rank != two.getRank()) 102 return {}; 103 104 SmallVector<int64_t, 4> join; 105 join.reserve(rank); 106 for (int64_t i = 0; i < rank; ++i) { 107 if (one.isDynamicDim(i)) { 108 join.push_back(two.getDimSize(i)); 109 continue; 110 } 111 if (two.isDynamicDim(i)) { 112 join.push_back(one.getDimSize(i)); 113 continue; 114 } 115 if (one.getDimSize(i) != two.getDimSize(i)) 116 return {}; 117 join.push_back(one.getDimSize(i)); 118 } 119 return RankedTensorType::get(join, one.getElementType()); 120 } 121 122 namespace { 123 124 /// Replaces chains of two tensor.cast operations by a single tensor.cast 125 /// operation if doing so does not remove runtime constraints. 126 struct ChainedTensorCast : public OpRewritePattern<CastOp> { 127 using OpRewritePattern<CastOp>::OpRewritePattern; 128 129 LogicalResult matchAndRewrite(CastOp tensorCast, 130 PatternRewriter &rewriter) const final { 131 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>(); 132 133 if (!tensorCastOperand) 134 return failure(); 135 136 auto sourceType = 137 tensorCastOperand.getOperand().getType().cast<TensorType>(); 138 auto intermediateType = tensorCastOperand.getType().cast<TensorType>(); 139 auto resultType = tensorCast.getType().cast<TensorType>(); 140 141 // We can remove the intermediate cast if joining all three produces the 142 // same result as just joining the source and result shapes. 143 auto firstJoin = 144 joinShapes(joinShapes(sourceType, intermediateType), resultType); 145 146 // The join might not exist if the cast sequence would fail at runtime. 147 if (!firstJoin) 148 return failure(); 149 150 // The newJoin always exists if the above join exists, it might just contain 151 // less information. 
If so, we cannot drop the intermediate cast, as doing 152 // so would remove runtime checks. 153 auto newJoin = joinShapes(sourceType, resultType); 154 if (firstJoin != newJoin) 155 return failure(); 156 157 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType, 158 tensorCastOperand.getOperand()); 159 return success(); 160 } 161 }; 162 163 } // namespace 164 165 void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 166 MLIRContext *context) { 167 results.insert<ChainedTensorCast>(context); 168 } 169 170 //===----------------------------------------------------------------------===// 171 // ExtractOp 172 //===----------------------------------------------------------------------===// 173 174 static LogicalResult verify(ExtractOp op) { 175 // Verify the # indices match if we have a ranked type. 176 if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>()) 177 if (tensorType.getRank() != static_cast<int64_t>(op.indices().size())) 178 return op.emitOpError("incorrect number of indices for extract_element"); 179 180 return success(); 181 } 182 183 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) { 184 // The tensor operand must be a known constant. 185 Attribute tensor = operands.front(); 186 if (!tensor) 187 return {}; 188 // If this is a splat elements attribute, simply return the value. All of the 189 // elements of a splat attribute are the same. 190 if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>()) 191 return splatTensor.getSplatValue(); 192 193 // Otherwise, collect the constant indices into the tensor. 194 SmallVector<uint64_t, 8> indices; 195 for (Attribute indice : llvm::drop_begin(operands, 1)) { 196 if (!indice || !indice.isa<IntegerAttr>()) 197 return {}; 198 indices.push_back(indice.cast<IntegerAttr>().getInt()); 199 } 200 201 // If this is an elements attribute, query the value at the given indices. 
202 auto elementsAttr = tensor.dyn_cast<ElementsAttr>(); 203 if (elementsAttr && elementsAttr.isValidIndex(indices)) 204 return elementsAttr.getValue(indices); 205 return {}; 206 } 207 208 //===----------------------------------------------------------------------===// 209 // TableGen'd op method definitions 210 //===----------------------------------------------------------------------===// 211 212 #define GET_OP_CLASSES 213 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc" 214