//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}
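
// As an informal illustration of the two outcomes of
// `castOrReallocMemRefValue` (a hedged sketch, not taken from a test;
// `#layout` stands for an arbitrary strided layout and all SSA names are
// invented):
//
//   // Guaranteed cast compatible (e.g. static -> dynamic size): emit a cast.
//   %0 = memref.cast %src : memref<5xf32> to memref<?xf32>
//
//   // Not guaranteed (e.g. dynamic -> static offset/stride): alloc + copy.
//   %dim = memref.dim %src, %c0 : memref<?xf32, #layout>
//   %alloc = memref.alloc(%dim) : memref<?xf32>
//   memref.copy %src, %alloc : memref<?xf32, #layout> to memref<?xf32>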

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of
/// the to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}
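
// A rough sketch of the ranked -> ranked case above (hedged; `#layout` is a
// placeholder for a more dynamic strided layout and the SSA names are
// invented):
//
//   %t  = bufferization.to_tensor %m : memref<?xf32>
//   %m2 = bufferization.to_memref %t : memref<?xf32, #layout>
//
// may be rewritten, via `castOrReallocMemRefValue`, to:
//
//   %m2 = memref.cast %m : memref<?xf32> to memref<?xf32, #layout>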

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       BufferizationState &state) {
  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty())
    return success();

  FailureOr<Value> alloc = state.createAlloc(rewriter, getLoc(), getResult());
  if (failed(alloc))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}

LogicalResult AllocTensorOp::verify() {
  if (getType().getNumDynamicDims() !=
      static_cast<int64_t>(dynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  return success();
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.dynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp =
        rewriter.create<AllocTensorOp>(op.getLoc(), newType, newDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       allocTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
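
// As an informal example of `reifyResultShapes` (hypothetical IR, not from a
// test): for
//   %0 = bufferization.alloc_tensor(%d) : tensor<?x5xf32>
// the reified result shape is {%d, %c5}, where %c5 is an
// `arith.constant 5 : index` created at the op's location.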

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
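
// A rough before/after sketch of the `SimplifyClones` pattern registered
// above (hedged; the `...` lines stand for arbitrary uses of %1, and the
// sketch assumes the simplest case where the source has no dealloc in this
// block):
//
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   ...
//   memref.dealloc %1 : memref<?xf32>
// -->
//   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
//   ...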

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};
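
// Illustrative example of `LoadOfToMemref` above (hedged; the SSA names and
// types are invented):
//
//   %m = bufferization.to_memref %t : memref<?x?xf32>
//   %v = memref.load %m[%i, %j] : memref<?x?xf32>
// -->
//   %v = tensor.extract %t[%i, %j] : tensor<?x?xf32>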

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"