//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
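
    // Illustrative example (hypothetical IR, for exposition only): an
    // scf.execute_region that yields a tensor, e.g.
    //
    //   %r = scf.execute_region -> tensor<5xf32> {
    //     ...
    //     scf.yield %t : tensor<5xf32>
    //   }
    //
    // is rewritten below into an scf.execute_region that yields a memref. A
    // bufferization.to_memref is inserted in front of the yield, and the
    // tensor uses of %r are rewired through a bufferization.to_tensor of the
    // new memref result.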

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    //   scf.yield %1
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }
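
  // Illustrative example (hypothetical IR, for exposition only): the
  // `bufferize` method below rewrites
  //
  //   %r = scf.if %c -> (tensor<5xf32>) {
  //     scf.yield %t0 : tensor<5xf32>
  //   } else {
  //     scf.yield %t1 : tensor<5xf32>
  //   }
  //
  // into an scf.if with a memref result type: bufferization.to_memref ops are
  // inserted at both yields, and all uses of %r are replaced with the
  // bufferized (memref) result.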

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
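    // E.g. (illustrative): if both branches yield the same tensor, or tensors
    // that the analysis has already determined to bufferize to equivalent
    // buffers, the result is reported as BufferRelation::Equivalent;
    // otherwise only aliasing is known and BufferRelation::None is returned.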
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered a write. This is
    // to simplify the analysis.
    // TODO: Consider doing sth. like isValueWritten.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    if (!opOperand.get().getType().isa<RankedTensorType>())
      return {};
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs()))
      if (it.value().getType().isa<TensorType>())
        indices.insert(it.index());
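
    // Illustrative example (hypothetical IR, for exposition only): a loop
    // such as
    //
    //   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
    //       -> (tensor<?xf32>) {
    //     ...
    //     scf.yield %t2 : tensor<?xf32>
    //   }
    //
    // is rewritten into an scf.for whose iter_args are memrefs: the init_args
    // are replaced with the buffers computed by `state.getBuffer`, the bbArgs
    // are wrapped back into tensors with bufferization.to_tensor at the top of
    // the loop body, and the yielded tensors are converted back with
    // bufferization.to_memref.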

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          ensureToMemrefOpIsValid(val, initArgs[index].getType());
          return rewriter.create<bufferization::ToMemrefOp>(
              val.getLoc(), initArgs[index].getType(), val);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are aliasing with their
  /// corresponding bbArgs. This is required because the i-th OpResult of an
  /// scf.for op is currently assumed to alias with the i-th iter_arg (in the
  /// absence of conflicts).
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
        // TODO: this could get resolved with copies but it can also turn into
        // swaps so we need to be careful about order of copies.
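        // (Illustrative: a loop that yields its iter_args in swapped order,
        // e.g. `scf.yield %b, %a` for `iter_args(%a = ..., %b = ...)`, fails
        // this check.)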
        return yieldOp->emitError()
               << "Yield operand #" << operand.getOperandNumber()
               << " does not bufferize to a buffer that is aliasing the "
                  "matching enclosing scf::for operand";
    }
    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
  registry.addOpInterface<ForOp, ForOpInterface>();
  registry.addOpInterface<IfOp, IfOpInterface>();
  registry.addOpInterface<YieldOp, YieldOpInterface>();
}