//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
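    //
    // For example (illustrative IR, not taken from a test), assuming %alias
    // aliases %source:
    //
    //   %clone = bufferization.clone %source : memref<?xf32> to memref<?xf32>
    //   memref.dealloc %alias : memref<?xf32>  // frees the buffer of %source
    //   "some.use"(%clone) : (memref<?xf32>) -> ()
    //   memref.dealloc %clone : memref<?xf32>  // candidate redundantDealloc
    //
    // Rewriting %clone into a cast of %source would make "some.use" read freed
    // memory, so any intervening op with a Free effect blocks the rewrite.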
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

/// Fold a tensor.dim of a to_tensor into a memref.dim of the source memref.
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
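///
/// Illustrative IR (hand-written, not taken from a test):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %r = bufferization.to_memref %t : memref<4xf32>
///
/// folds to:
///
///   %r = memref.cast %m : memref<?xf32> to memref<4xf32>
///
/// If the cast would go from a dynamic to a static offset or stride, the
/// pattern instead allocates a new buffer and copies into it.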
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
    // Bail unless we have a to_tensor + to_memref pair with different types.
    // `ToMemrefOp::fold` handles the same-type case.
    if (!memrefToTensor ||
        memrefToTensor.memref().getType() == toMemref.getType())
      return failure();
    // If types are definitely not cast-compatible, bail.
    if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
                                           toMemref.getType()))
      return failure();

    // We already know that the types are potentially cast-compatible. However,
    // in case the affine maps are different, we may need to use a copy if we
    // go from a dynamic to a static offset or stride (the canonicalization
    // cannot know at this point that the cast is really valid).
    auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
      int64_t sourceOffset, targetOffset;
      SmallVector<int64_t, 4> sourceStrides, targetStrides;
      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
        return false;
      auto dynamicToStatic = [](int64_t a, int64_t b) {
        return a == MemRefType::getDynamicStrideOrOffset() &&
               b != MemRefType::getDynamicStrideOrOffset();
      };
      if (dynamicToStatic(sourceOffset, targetOffset))
        return false;
      for (auto it : llvm::zip(sourceStrides, targetStrides))
        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
          return false;
      return true;
    };

    auto memrefToTensorType =
        memrefToTensor.memref().getType().dyn_cast<MemRefType>();
    auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
    if (memrefToTensorType && toMemrefType &&
        !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
      MemRefType resultType = toMemrefType;
      auto loc = toMemref.getLoc();
      SmallVector<Value, 4> dynamicOperands;
      for (int i = 0; i < resultType.getRank(); ++i) {
        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
          continue;
        auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
        Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
        dynamicOperands.push_back(size);
      }
      auto copy =
          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
      rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
      rewriter.replaceOp(toMemref, {copy});
    } else {
      rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                  memrefToTensor.memref());
    }
    return success();
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold a memref.dim of a to_memref into a tensor.dim of the source tensor.
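///
/// Illustrative IR (hand-written, not taken from a test):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes:
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>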
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"