//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
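///
/// For illustration (a sketch, using the hypothetical `consumer` op from the
/// example above): given
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// this rewrites the operand of `consumer` in place to use %0 directly and
/// returns success; if no operand is produced by a foldable cast, it returns
/// failure. No new operations are created.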
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
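    //
    // For example (a sketch): with source tensor<?x?xf32>, intermediate
    // tensor<4x?xf32>, and result tensor<?x?xf32>, firstJoin is
    // tensor<4x?xf32> while the source/result join is only tensor<?x?xf32>;
    // dropping the intermediate cast would lose the runtime check that
    // dimension 0 is 4.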
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indexAttr : llvm::drop_begin(operands, 1)) {
    if (!indexAttr || !indexAttr.isa<IntegerAttr>())
      return {};
    indices.push_back(indexAttr.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
  return {};
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out-of-bounds accesses. This can happen in invalid code that
    // will never execute.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(InsertOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
  Attribute scalar = operands[0];
  Attribute dest = operands[1];
  if (scalar && dest)
    if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
      if (scalar == splatDest.getSplatValue())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
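  //
  // For example (a sketch): for a result type tensor<?x3xf32>, the body must
  // take two index arguments and be terminated as follows; anything else is
  // rejected below.
  //
  //   ^bb0(%i: index, %j: index):
  //     ...
  //     tensor.yield %elem : f32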
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations whose dynamic extents are defined
/// by constants into the equivalent operation with those extents expressed in
/// the result type instead. A tensor.cast back to the original result type is
/// inserted so that the resulting IR stays well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        generateOp.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = generateOp.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == generateOp.dynamicExtents().size())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(generateOp.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index): // no predecessors
///   <computation>
///   tensor.yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
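/// (Otherwise, cloning the body at each extraction site could duplicate
/// those effects, so the rewrite would not be sound.)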
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto generateOp = extract.tensor().getDefiningOp<GenerateOp>();
    if (!generateOp || !wouldOpBeTriviallyDead(generateOp))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = generateOp.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.source(), extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

static LogicalResult verify(ReshapeOp op) {
  TensorType operandType = op.source().getType().cast<TensorType>();
  TensorType resultType = op.result().getType().cast<TensorType>();

  if (operandType.getElementType() != resultType.getElementType())
    return op.emitOpError("element types of source and destination tensor "
                          "types should be the same");

  int64_t shapeSize =
      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return op.emitOpError("source and destination tensor should have the "
                              "same number of elements");
    }
    if (shapeSize == TensorType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"