//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, a cast is valid but may still
  // fail at runtime. To ensure that we only generate casts that always succeed
  // at runtime, we check a few extra conditions in
  // `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
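///
/// As an illustrative sketch only (the value names and types below are made up
/// for the example, and the exact assembly syntax may differ):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %m2 = bufferization.to_memref %t : memref<4xf32>
///
/// is folded so that %m2 is replaced by a memref.cast of %m (or, when the cast
/// is not guaranteed to succeed at runtime, by an alloc + copy).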
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast.
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
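    // (`memref::findDealloc` is assumed to return `llvm::None` in that case,
    // and an `Optional` wrapping `nullptr` when no dealloc was found; the
    // nullptr case is handled by the checks below.)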
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
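    // Illustrative sketch only (made-up names and types; exact assembly syntax
    // may differ):
    //   %m = bufferization.to_memref %t : memref<?xf32>
    //   %t2 = bufferization.to_tensor %m : memref<?xf32>
    // %t2 folds to %t only if the two ops are directly adjacent in the same
    // block.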
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
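/// Illustrative sketch only (made-up names and types; exact assembly syntax
/// may differ):
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// is rewritten to:
///   %d = tensor.dim %t, %c0 : tensor<?xf32>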
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"