//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Casting to the same type, nothing to do.
  if (srcType == destType)
    return value;

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
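  // Fallback: the cast is not guaranteed to succeed at runtime, so allocate a
  // fresh buffer of the destination type and copy the data over. Roughly, the
  // generated IR looks like the following (a sketch only; value names and
  // types are illustrative):
  //   %alloc = memref.alloc(%dyn_sizes) : memref<...>
  //   memref.copy %value, %alloc : memref<...> to memref<...>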
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
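    // For example, a dealloc of another alias of `source` placed between the
    // clone and `redundantDealloc` would free memory that the rewritten cast
    // still reads from, so any op with a Free effect in that range
    // (conservatively, on any value) aborts the rewrite.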
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
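///
/// An illustrative example (a sketch; value names and types are made up):
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %m2 = bufferization.to_memref %t : memref<4xf32>
/// Here %m2 is replaced by either a `memref.cast` of %m or, if the cast is not
/// guaranteed to succeed at runtime, a new allocation plus copy (see
/// `castOrReallocMemRefValue`).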
static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref,
                                              bool allowSameType = true) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Function can be configured to only handle cases where a cast is needed.
  if (!allowSameType && srcType == destType)
    return failure();

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast.
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
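///
/// For example (a sketch; value names and types are illustrative):
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// becomes
///   %d = tensor.dim %t, %c0 : tensor<?xf32>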
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"