//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
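    // Illustrative sketch (hypothetical IR, not taken from a test): an op like
    //   %r = scf.execute_region -> tensor<5xf32> {
    //     scf.yield %t : tensor<5xf32>
    //   }
    // is rebuilt below with a memref result type (the exact memref layout is
    // whatever `getMemRefType` produces for the given bufferization options).
    // The tensor yield value is wrapped in bufferization.to_memref inside the
    // region, and each memref result is wrapped in bufferization.to_tensor
    // after the new op, so remaining tensor users are unaffected.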
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    //   scf.yield %1
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
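    //
    // For example (illustrative sketch, not taken from a test): in
    //   %r = scf.if %c -> (tensor<?xf32>) {
    //     scf.yield %t : tensor<?xf32>
    //   } else {
    //     scf.yield %t : tensor<?xf32>
    //   }
    // both branches yield the same value, so %r can be treated as equivalent
    // to %t's buffer. If the two yield values may bufferize to different
    // buffers, the relation is conservatively reported as BufferRelation::None.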
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Tensor iter_args of scf::ForOps are always considered a write. This is
    // to simplify the analysis.
    // TODO: Consider doing something like isValueWritten.
    return true;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    if (!opOperand.get().getType().isa<RankedTensorType>())
      return {};
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under it:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs()))
      if (it.value().getType().isa<TensorType>())
        indices.insert(it.index());

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
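    // E.g. (illustrative): with `indices` = {0} and values (%a, %b), the
    // lambda below produces (func(%a, 0), %b); positions not in `indices`
    // are forwarded unchanged.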
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          ensureToMemrefOpIsValid(val, initArgs[index].getType());
          return rewriter.create<bufferization::ToMemrefOp>(
              val.getLoc(), initArgs[index].getType(), val);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const BufferizationState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
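    // For instance (illustrative), rewriting `scf.yield %t` to yield a fresh
    // copy of %t would make the enclosing op return a new allocation, which
    // is exactly the situation this in-place requirement avoids.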
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

LogicalResult mlir::scf::assertScfForAliasingProperties(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  LogicalResult status = success();

  op->walk([&](scf::ForOp forOp) {
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
        // TODO: This could be resolved with copies, but it can also turn into
        // swaps, so we need to be careful about the order of copies.
        status =
            yieldOp->emitError()
            << "Yield operand #" << operand.getOperandNumber()
            << " does not bufferize to a buffer that is aliasing the matching"
            << " enclosing scf::for operand";
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return status;
}

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
  registry.addOpInterface<ForOp, ForOpInterface>();
  registry.addOpInterface<IfOp, IfOpInterface>();
  registry.addOpInterface<YieldOp, YieldOpInterface>();
  registry
      .addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>();
}