//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

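
// Illustrative sketch (not part of the original code; the layout map is chosen
// arbitrarily): casting away a dynamic offset, e.g. from
//
//   memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
//
// to the identity-layout memref<?xf32>, is cast compatible but not guaranteed
// to succeed at runtime, so the helper above allocates a new buffer of the
// destination type and fills it with memref.copy instead of emitting a
// memref.cast.
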
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

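
// Illustrative example (not from the original code; types chosen arbitrarily):
// for
//
//   %t = bufferization.to_tensor %m : memref<4xf32>
//   %r = bufferization.to_memref %t : memref<?xf32>
//
// foldToMemrefToTensorPair rewrites %r to a single memref.cast of %m to
// memref<?xf32>, since this static-to-dynamic ranked cast is guaranteed to
// succeed at runtime.
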
//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       BufferizationState &state) {
  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty())
    return success();

  FailureOr<Value> alloc = state.createAlloc(rewriter, getLoc(), getResult());
  if (failed(alloc))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}

void AllocTensorOp::build(OpBuilder &b, OperationState &result,
                          ArrayRef<OpFoldResult> sizes, Type elementType,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value, 4> dynamicSizes;
  SmallVector<int64_t, 4> staticSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  auto resultType = RankedTensorType::get(staticSizes, elementType);
  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
  result.addAttributes(attrs);
}

LogicalResult AllocTensorOp::verify() {
  RankedTensorType resultType = getType();
  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
      static_sizes().cast<ArrayAttr>(),
      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));

  if (failed(verifyListOfOperandsOrIntegers(
          *this, "sizes", resultType.getRank(), static_sizes(), sizes(),
          ShapedType::isDynamic)))
    return failure();

  if (static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
    return emitError("expected ") << resultType.getRank() << " sizes values";

  Type expectedType = AllocTensorOp::inferResultType(
      staticSizes, resultType.getElementType(), resultType.getEncoding());
  if (resultType != expectedType) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}

Type AllocTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
                                    Type elementType, Attribute encoding) {
  return RankedTensorType::get(staticSizes, elementType, encoding);
}

SmallVector<OpFoldResult> AllocTensorOp::getMixedSizes() {
  SmallVector<OpFoldResult> mixedSizes;
  mixedSizes.reserve(getType().getRank());
  unsigned dynamicValIndex = 0;
  for (Attribute attr : static_sizes()) {
    auto intAttr = attr.cast<IntegerAttr>();
    if (!ShapedType::isDynamic(intAttr.getInt())) {
      mixedSizes.push_back(intAttr);
      continue;
    }
    mixedSizes.push_back(sizes()[dynamicValIndex++]);
  }
  return mixedSizes;
}

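
// Illustrative note (not from the original code; types chosen arbitrarily):
// for an op such as
//
//   %0 = bufferization.alloc_tensor [%d, 8] : tensor<?x8xf32>
//
// getMixedSizes() returns {%d, IntegerAttr(8)}, i.e. dynamic sizes as SSA
// values and static sizes as attributes, in dimension order.
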
namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor [%arg0, %c5] : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor [%arg0, 5] : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> dynamicSizes;
    SmallVector<int64_t, 4> staticSizes;
    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
      // If the size is already static, nothing to do.
      if (!op.isDynamicSize(i)) {
        staticSizes.push_back(op.getStaticSize(i));
        continue;
      }

      // If the size is dynamic but defined using a `constant` op, get the
      // constant value to find the static size to use.
      unsigned operandNum = op.getIndexOfDynamicSize(i);
      Value sizeOperand = op.getOperand(operandNum);
      if (auto constantIndexOp =
              sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
        staticSizes.push_back(constantIndexOp.value());
        continue;
      }

      // Fallback case. Keep the size dynamic.
      dynamicSizes.push_back(sizeOperand);
      staticSizes.push_back(ShapedType::kDynamicSize);
    }
    RankedTensorType newType =
        RankedTensorType::get(staticSizes, op.getType().getElementType());
    if (newType == op.getType())
      return failure();
    auto newOp =
        rewriter.create<AllocTensorOp>(op.getLoc(), newType, dynamicSizes,
                                       rewriter.getI64ArrayAttr(staticSizes));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.isDynamicSize(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       allocTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicSize(dim))
          return getDynamicSize(dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

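
// Illustrative note (not from the original code; types chosen arbitrarily):
// memref::foldMemRefCast absorbs a defining memref.cast with a ranked source
// into the clone's operand, e.g. in
//
//   %0 = memref.cast %m : memref<4xf32> to memref<?xf32>
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//
// the clone is folded in place to read %m directly; its result type is
// unchanged.
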
namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

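
// Illustrative example (not from the original code; types and op placement are
// arbitrary): in
//
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   "some.use"(%1) : (memref<?xf32>) -> ()
//   memref.dealloc %1 : memref<?xf32>
//
// the clone is replaced by a (here no-op) memref.cast of %0 and the dealloc of
// the clone is erased, provided %0 is not deallocated in the same block and no
// operation in between may free memory.
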
//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

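
// Illustrative example of LoadOfToMemref (not from the original code; types
// chosen arbitrarily):
//
//   %m = bufferization.to_memref %t : memref<?xf32>
//   %v = memref.load %m[%i] : memref<?xf32>
//
// is rewritten to
//
//   %v = tensor.extract %t[%i] : tensor<?xf32>
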
/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"