//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Casting to the same type, nothing to do.
  if (srcType == destType)
    return value;

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to insert a copy if we
  // go from a dynamic to a static offset or stride (the canonicalization
  // cannot know at this point that the cast is really compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, the cast is valid IR, but it
  // may still fail at runtime. To ensure that we only generate casts that
  // always succeed at runtime, we check a few extra conditions in
  // `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use the alloc/memcpy callbacks from BufferizationOptions if called
  // via the BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}
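
// Illustrative note (schematic types chosen for exposition, not taken from
// the upstream tests): casting a value of type
//   memref<?xf32, offset: ?, strides: [1]>
// to
//   memref<?xf32>
// goes from a dynamic to a static offset, which `isGuaranteedCastCompatible`
// rejects, so the helper above falls back to `memref.alloc` + `memref.copy`
// instead of emitting a `memref.cast`.
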
//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
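///
/// Illustrative example (schematic IR, not taken from the upstream tests):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   "use"(%1) : (memref<?xf32>) -> ()
///   memref.dealloc %1 : memref<?xf32>
///
/// can be rewritten to
///
///   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
///   "use"(%1) : (memref<?xf32>) -> ()
///
/// provided (among the other checks below) the source's deallocation, if any,
/// is not in the same block as the clone's deallocation.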
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

/// Fold `tensor.dim` of a `to_tensor` op into a `memref.dim` on the
/// corresponding memref.
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
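///
/// Illustrative example (schematic IR, not taken from the upstream tests):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %r = bufferization.to_memref %t : memref<?xf32>
///
/// folds to
///
///   %r = memref.cast %m : memref<4xf32> to memref<?xf32>
///
/// while the cast is omitted entirely when the two memref types already match.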
static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref,
                                              bool allowSameType = true) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // The function can be configured to only handle cases where a cast is
  // needed.
  if (!allowSameType && srcType == destType)
    return failure();

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when a type mismatch prevents `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold the dim of a to_memref into the dim of the corresponding tensor.
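///
/// Illustrative example (schematic IR, not taken from the upstream tests):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// is rewritten to
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>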
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"