//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        findDealloc(cloneOp.output());
    // Skip if either the clone or the source value has more than one
    // deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace
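
// A rough illustration of the SimplifyClones rewrite above (hypothetical IR,
// not taken from an actual test case). Assuming the source %0 is deallocated
// in a different block (or not at all within this region),
//
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   "use"(%1) : (memref<?xf32>) -> ()
//   memref.dealloc %1 : memref<?xf32>
//
// becomes
//
//   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
//   "use"(%1) : (memref<?xf32>) -> ()
//
// The dealloc of the clone in this block is the redundant one and is erased;
// whatever deallocates %0 elsewhere remains responsible for the buffer.
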
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

/// Fold a tensor.dim of a to_tensor into a memref.dim of the source memref.
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};
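
// Schematically (hypothetical IR, for illustration only), ToMemrefOfCast
// rewrites
//
//   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
//   %1 = bufferization.to_memref %0 : memref<?xf32>
//
// into
//
//   %0 = bufferization.to_memref %t : memref<4xf32>
//   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
//
// so that the cast is expressed on the memref side, where further memref
// canonicalizations can pick it up.
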
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref,
                                              bool allowSameType = true) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  // A to_tensor + to_memref pair with matching types can be folded without
  // inserting a cast.
  if (memrefToTensor.memref().getType() == toMemref.getType()) {
    if (!allowSameType)
      // The function can be configured to only handle cases where a cast is
      // needed.
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  // If types are definitely not cast-compatible, bail.
  if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
                                         toMemref.getType()))
    return failure();

  // We already know that the types are potentially cast-compatible. However,
  // if the affine maps differ, we may need to use a copy when going from a
  // dynamic to a static offset or stride (at this point the canonicalization
  // cannot know that the cast is really valid).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  auto memrefToTensorType =
      memrefToTensor.memref().getType().dyn_cast<MemRefType>();
  auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
  if (memrefToTensorType && toMemrefType &&
      !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
    MemRefType resultType = toMemrefType;
    auto loc = toMemref.getLoc();
    SmallVector<Value, 4> dynamicOperands;
    for (int i = 0; i < resultType.getRank(); ++i) {
      if (resultType.getShape()[i] != ShapedType::kDynamicSize)
        continue;
      auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
      Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
      dynamicOperands.push_back(size);
    }
    // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
    // BufferizableOpInterface impl of ToMemrefOp.
    auto copy =
        rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
    rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
    rewriter.replaceOp(toMemref, {copy});
  } else {
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memrefToTensor.memref());
  }
  return success();
}
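
// As a sketch (hypothetical IR), when the types differ but the cast is
// provably fine, the helper above rewrites
//
//   %0 = bufferization.to_tensor %m : memref<4xf32>
//   %1 = bufferization.to_memref %0 : memref<?xf32>
//
// into
//
//   %1 = memref.cast %m : memref<4xf32> to memref<?xf32>
//
// and falls back to memref.alloc + memref.copy when it cannot rule out a
// dynamic-to-static offset or stride mismatch.
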
/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold a memref.dim of a to_memref into a tensor.dim of the source tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"