//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source tensor type has more static information than the result type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
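/// For example, joining tensor<?x16xf32> with tensor<8x?xf32> yields
/// tensor<8x16xf32>, joining an unranked tensor<*xf32> with tensor<4xf32>
/// yields tensor<4xf32>, and joining tensor<4xf32> with tensor<8xf32> yields
/// a null type, since the static sizes conflict.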
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify that the number of indices matches the rank if the tensor type is
  // ranked.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract_element");

  return success();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
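  // For instance (an illustrative case), if the tensor operand folds to the
  // splat attribute of `constant dense<3.0> : tensor<4x4xf32>`, the extract
  // folds to the scalar 3.0 no matter which indices are supplied.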
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute index : llvm::drop_begin(operands, 1)) {
    if (!index || !index.isa<IntegerAttr>())
      return {};
    indices.push_back(index.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
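  // For example, a generate op with result type tensor<?x?xf32> must have a
  // body block of the form
  //   ^bb0(%i: index, %j: index):
  // i.e. one index-typed argument per dimension of the result.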
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations whose dynamic extent operands are
/// constants by folding those constants into the result type. A tensor.cast
/// back to the original type is inserted so that the resulting IR is still
/// well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
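///
/// For example (an illustrative instance, assuming the body computes a `dim`
/// of some tensor %source):
///
/// %tensor = tensor.generate %size {
///   ^bb0(%arg0: index):  // no predecessors
///   %elem = dim %source, %arg0 : tensor<?xf32>
///   yield %elem : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// canonicalizes to
///
/// %extracted_element = dim %source, %c0 : tensor<?xf32>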
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.source(), extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
                 StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"