//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType;
    if (resultTensorType.isa<RankedTensorType>()) {
      resultMemRefType =
          getContiguousMemRefType(resultTensorType, layout, memorySpace);
    } else {
      resultMemRefType =
          getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
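///
/// Illustrative sketch (hypothetical types; the exact memref type of the
/// bufferized source depends on the bufferization options):
///
///   %d = tensor.dim %t, %c0 : tensor<?x8xf32>
///
/// becomes, with %m being the buffer of %t:
///
///   %d = memref.dim %m, %c0 : memref<?x8xf32>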
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
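    // Illustrative note (hypothetical shapes): for a rank-reducing slice,
    // e.g. extracting a tensor<16xf32> from a tensor<8x16xf32> source, the
    // call below infers a rank-reduced memref result type whose layout
    // (strides and offset) encodes the slice's position within the source
    // buffer.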
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return op->getOpResult(0);
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
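///
/// Illustrative sketch of a matching pair (mirrors the example in
/// InsertSliceOpInterface::isNotConflicting below; %t, %a, %c are
/// placeholders):
///
///   %0 = tensor.extract_slice %t[%a][%c][1]
///   %1 = <some in-place computation on %0>
///   %2 = tensor.insert_slice %1 into %t[%a][%c][1]
///
/// Here %0 is extracted from and %2 is inserted into the same tensor %t with
/// identical offsets/sizes/strides, so the pair matches.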
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
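    // Illustrative note (hypothetical types): inserting a tensor<16xf32> into
    // a tensor<8x16xf32> destination is such a rank-reducing case. The
    // subview taken below must match the destination rank, so the lists are
    // expanded; the callback returns a dimension size either as a static
    // index attribute or, for dynamic dimensions, via memref.dim on the
    // destination buffer.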
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}
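
// Sketch of a typical registration call site (assumed caller-side setup, not
// part of this file): the external models above become visible to the
// bufferization once the function is invoked on a DialectRegistry before the
// MLIRContext is constructed, e.g.
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);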