//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}
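// Example (illustrative sketch, not taken from the source or its tests): if
// the source has a dynamic offset or stride where the destination expects a
// static one, the memref.cast is not guaranteed to succeed at runtime, so the
// helper above emits an allocation plus a copy instead. Roughly:
//
//   // %src : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
//   %c0 = arith.constant 0 : index
//   %dim = memref.dim %src, %c0
//       : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
//   %dst = memref.alloc(%dim) : memref<?xf32>
//   memref.copy %src, %dst
//       : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<?xf32>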
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}
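// Example (illustrative sketch, names invented): a to_memref of a to_tensor
// with a different but cast-compatible type is rewritten to a memref.cast of
// the original memref. Roughly:
//
//   %t = bufferization.to_tensor %m : memref<4xf32>
//   %r = bufferization.to_memref %t : memref<?xf32>
//
// folds to:
//
//   %r = memref.cast %m : memref<4xf32> to memref<?xf32>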
//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};
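// Example (illustrative sketch, invented IR): assuming the clone and its
// dealloc are in the same block, the source is not deallocated in that block,
// and no other deallocation sits in between, the pattern above rewrites
//
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   ... uses of %1 ...
//   memref.dealloc %1 : memref<?xf32>
//
// into
//
//   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
//   ... uses of %1 ...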
} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};
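// Example (illustrative sketch, invented IR): the pattern above pushes the
// cast to the memref side, roughly rewriting
//
//   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
//   %m = bufferization.to_memref %c : memref<?xf32>
//
// into
//
//   %0 = bufferization.to_memref %t : memref<4xf32>
//   %m = memref.cast %0 : memref<4xf32> to memref<?xf32>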
/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"