1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass performs a sparse conditional constant propagation 10 // in MLIR. It identifies values known to be constant, propagates that 11 // information throughout the IR, and replaces them. This is done with an 12 // optimistic dataflow analysis that assumes that all values are constant until 13 // proven otherwise. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/Interfaces/ControlFlowInterfaces.h" 21 #include "mlir/Interfaces/SideEffectInterfaces.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Transforms/FoldUtils.h" 24 #include "mlir/Transforms/Passes.h" 25 26 using namespace mlir; 27 28 namespace { 29 /// This class represents a single lattice value. A lattive value corresponds to 30 /// the various different states that a value in the SCCP dataflow analysis can 31 /// take. See 'Kind' below for more details on the different states a value can 32 /// take. 33 class LatticeValue { 34 enum Kind { 35 /// A value with a yet to be determined value. This state may be changed to 36 /// anything. 37 Unknown, 38 39 /// A value that is known to be a constant. This state may be changed to 40 /// overdefined. 41 Constant, 42 43 /// A value that cannot statically be determined to be a constant. This 44 /// state cannot be changed. 45 Overdefined 46 }; 47 48 public: 49 /// Initialize a lattice value with "Unknown". 50 LatticeValue() 51 : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {} 52 /// Initialize a lattice value with a constant. 53 LatticeValue(Attribute attr, Dialect *dialect) 54 : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {} 55 56 /// Returns true if this lattice value is unknown. 57 bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; } 58 59 /// Mark the lattice value as overdefined. 60 void markOverdefined() { 61 constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined); 62 constantDialect = nullptr; 63 } 64 65 /// Returns true if the lattice is overdefined. 66 bool isOverdefined() const { 67 return constantAndTag.getInt() == Kind::Overdefined; 68 } 69 70 /// Mark the lattice value as constant. 71 void markConstant(Attribute value, Dialect *dialect) { 72 constantAndTag.setPointerAndInt(value, Kind::Constant); 73 constantDialect = dialect; 74 } 75 76 /// If this lattice is constant, return the constant. Returns nullptr 77 /// otherwise. 78 Attribute getConstant() const { return constantAndTag.getPointer(); } 79 80 /// If this lattice is constant, return the dialect to use when materializing 81 /// the constant. 82 Dialect *getConstantDialect() const { 83 assert(getConstant() && "expected valid constant"); 84 return constantDialect; 85 } 86 87 /// Merge in the value of the 'rhs' lattice into this one. Returns true if the 88 /// lattice value changed. 89 bool meet(const LatticeValue &rhs) { 90 // If we are already overdefined, or rhs is unknown, there is nothing to do. 91 if (isOverdefined() || rhs.isUnknown()) 92 return false; 93 // If we are unknown, just take the value of rhs. 94 if (isUnknown()) { 95 constantAndTag = rhs.constantAndTag; 96 constantDialect = rhs.constantDialect; 97 return true; 98 } 99 100 // Otherwise, if this value doesn't match rhs go straight to overdefined. 101 if (constantAndTag != rhs.constantAndTag) { 102 markOverdefined(); 103 return true; 104 } 105 return false; 106 } 107 108 private: 109 /// The attribute value if this is a constant and the tag for the element 110 /// kind. 111 llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag; 112 113 /// The dialect the constant originated from. This is only valid if the 114 /// lattice is a constant. This is not used as part of the key, and is only 115 /// needed to materialize the held constant if necessary. 116 Dialect *constantDialect; 117 }; 118 119 /// This class contains various state used when computing the lattice of a 120 /// callable operation. 121 class CallableLatticeState { 122 public: 123 /// Build a lattice state with a given callable region, and a specified number 124 /// of results to be initialized to the default lattice value (Unknown). 125 CallableLatticeState(Region *callableRegion, unsigned numResults) 126 : callableArguments(callableRegion->getArguments()), 127 resultLatticeValues(numResults) {} 128 129 /// Returns the arguments to the callable region. 130 Block::BlockArgListType getCallableArguments() const { 131 return callableArguments; 132 } 133 134 /// Returns the lattice value for the results of the callable region. 135 MutableArrayRef<LatticeValue> getResultLatticeValues() { 136 return resultLatticeValues; 137 } 138 139 /// Add a call to this callable. This is only used if the callable defines a 140 /// symbol. 141 void addSymbolCall(Operation *op) { symbolCalls.push_back(op); } 142 143 /// Return the calls that reference this callable. This is only used 144 /// if the callable defines a symbol. 145 ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; } 146 147 private: 148 /// The arguments of the callable region. 149 Block::BlockArgListType callableArguments; 150 151 /// The lattice state for each of the results of this region. The return 152 /// values of the callable aren't SSA values, so we need to track them 153 /// separately. 154 SmallVector<LatticeValue, 4> resultLatticeValues; 155 156 /// The calls referencing this callable if this callable defines a symbol. 157 /// This removes the need to recompute symbol references during propagation. 158 /// Value based references are trivial to resolve, so they can be done 159 /// in-place. 160 SmallVector<Operation *, 4> symbolCalls; 161 }; 162 163 /// This class represents the solver for the SCCP analysis. This class acts as 164 /// the propagation engine for computing which values form constants. 165 class SCCPSolver { 166 public: 167 /// Initialize the solver with the given top-level operation. 168 SCCPSolver(Operation *op); 169 170 /// Run the solver until it converges. 171 void solve(); 172 173 /// Rewrite the given regions using the computing analysis. This replaces the 174 /// uses of all values that have been computed to be constant, and erases as 175 /// many newly dead operations. 176 void rewrite(MLIRContext *context, MutableArrayRef<Region> regions); 177 178 private: 179 /// Initialize the set of symbol defining callables that can have their 180 /// arguments and results tracked. 'op' is the top-level operation that SCCP 181 /// is operating on. 182 void initializeSymbolCallables(Operation *op); 183 184 /// Replace the given value with a constant if the corresponding lattice 185 /// represents a constant. Returns success if the value was replaced, failure 186 /// otherwise. 187 LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder, 188 Value value); 189 190 /// Visit the users of the given IR that reside within executable blocks. 191 template <typename T> 192 void visitUsers(T &value) { 193 for (Operation *user : value.getUsers()) 194 if (isBlockExecutable(user->getBlock())) 195 visitOperation(user); 196 } 197 198 /// Visit the given operation and compute any necessary lattice state. 199 void visitOperation(Operation *op); 200 201 /// Visit the given call operation and compute any necessary lattice state. 202 void visitCallOperation(CallOpInterface op); 203 204 /// Visit the given callable operation and compute any necessary lattice 205 /// state. 206 void visitCallableOperation(Operation *op); 207 208 /// Visit the given operation, which defines regions, and compute any 209 /// necessary lattice state. This also resolves the lattice state of both the 210 /// operation results and any nested regions. 211 void visitRegionOperation(Operation *op, 212 ArrayRef<Attribute> constantOperands); 213 214 /// Visit the given set of region successors, computing any necessary lattice 215 /// state. The provided function returns the input operands to the region at 216 /// the given index. If the index is 'None', the input operands correspond to 217 /// the parent operation results. 218 void visitRegionSuccessors( 219 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 220 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion); 221 222 /// Visit the given terminator operation and compute any necessary lattice 223 /// state. 224 void visitTerminatorOperation(Operation *op, 225 ArrayRef<Attribute> constantOperands); 226 227 /// Visit the given terminator operation that exits a callable region. These 228 /// are terminators with no CFG successors. 229 void visitCallableTerminatorOperation(Operation *callable, 230 Operation *terminator); 231 232 /// Visit the given block and compute any necessary lattice state. 233 void visitBlock(Block *block); 234 235 /// Visit argument #'i' of the given block and compute any necessary lattice 236 /// state. 237 void visitBlockArgument(Block *block, int i); 238 239 /// Mark the entry block of the given region as executable. Returns false if 240 /// the block was already marked executable. If `markArgsOverdefined` is true, 241 /// the arguments of the entry block are also set to overdefined. 242 bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined); 243 244 /// Mark the given block as executable. Returns false if the block was already 245 /// marked executable. 246 bool markBlockExecutable(Block *block); 247 248 /// Returns true if the given block is executable. 249 bool isBlockExecutable(Block *block) const; 250 251 /// Mark the edge between 'from' and 'to' as executable. 252 void markEdgeExecutable(Block *from, Block *to); 253 254 /// Return true if the edge between 'from' and 'to' is executable. 255 bool isEdgeExecutable(Block *from, Block *to) const; 256 257 /// Mark the given value as overdefined. This means that we cannot refine a 258 /// specific constant for this value. 259 void markOverdefined(Value value); 260 261 /// Mark all of the given values as overdefined. 262 template <typename ValuesT> 263 void markAllOverdefined(ValuesT values) { 264 for (auto value : values) 265 markOverdefined(value); 266 } 267 template <typename ValuesT> 268 void markAllOverdefined(Operation *op, ValuesT values) { 269 markAllOverdefined(values); 270 opWorklist.push_back(op); 271 } 272 template <typename ValuesT> 273 void markAllOverdefinedAndVisitUsers(ValuesT values) { 274 for (auto value : values) { 275 auto &lattice = latticeValues[value]; 276 if (!lattice.isOverdefined()) { 277 lattice.markOverdefined(); 278 visitUsers(value); 279 } 280 } 281 } 282 283 /// Returns true if the given value was marked as overdefined. 284 bool isOverdefined(Value value) const; 285 286 /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' 287 /// corresponds to the parent operation of 'to'. 288 void meet(Operation *owner, LatticeValue &to, const LatticeValue &from); 289 290 /// The lattice for each SSA value. 291 DenseMap<Value, LatticeValue> latticeValues; 292 293 /// The set of blocks that are known to execute, or are intrinsically live. 294 SmallPtrSet<Block *, 16> executableBlocks; 295 296 /// The set of control flow edges that are known to execute. 297 DenseSet<std::pair<Block *, Block *>> executableEdges; 298 299 /// A worklist containing blocks that need to be processed. 300 SmallVector<Block *, 64> blockWorklist; 301 302 /// A worklist of operations that need to be processed. 303 SmallVector<Operation *, 64> opWorklist; 304 305 /// The callable operations that have their argument/result state tracked. 306 DenseMap<Operation *, CallableLatticeState> callableLatticeState; 307 308 /// A map between a call operation and the resolved symbol callable. This 309 /// avoids re-resolving symbol references during propagation. Value based 310 /// callables are trivial to resolve, so they can be done in-place. 311 DenseMap<Operation *, Operation *> callToSymbolCallable; 312 313 /// A symbol table used for O(1) symbol lookups during simplification. 314 SymbolTableCollection symbolTable; 315 }; 316 } // end anonymous namespace 317 318 SCCPSolver::SCCPSolver(Operation *op) { 319 /// Initialize the solver with the regions within this operation. 320 for (Region ®ion : op->getRegions()) { 321 // Mark the entry block as executable. The values passed to these regions 322 // are also invisible, so mark any arguments as overdefined. 323 markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); 324 } 325 initializeSymbolCallables(op); 326 } 327 328 void SCCPSolver::solve() { 329 while (!blockWorklist.empty() || !opWorklist.empty()) { 330 // Process any operations in the op worklist. 331 while (!opWorklist.empty()) 332 visitUsers(*opWorklist.pop_back_val()); 333 334 // Process any blocks in the block worklist. 335 while (!blockWorklist.empty()) 336 visitBlock(blockWorklist.pop_back_val()); 337 } 338 } 339 340 void SCCPSolver::rewrite(MLIRContext *context, 341 MutableArrayRef<Region> initialRegions) { 342 SmallVector<Block *, 8> worklist; 343 auto addToWorklist = [&](MutableArrayRef<Region> regions) { 344 for (Region ®ion : regions) 345 for (Block &block : region) 346 if (isBlockExecutable(&block)) 347 worklist.push_back(&block); 348 }; 349 350 // An operation folder used to create and unique constants. 351 OperationFolder folder(context); 352 OpBuilder builder(context); 353 354 addToWorklist(initialRegions); 355 while (!worklist.empty()) { 356 Block *block = worklist.pop_back_val(); 357 358 // Replace any block arguments with constants. 359 builder.setInsertionPointToStart(block); 360 for (BlockArgument arg : block->getArguments()) 361 (void)replaceWithConstant(builder, folder, arg); 362 363 for (Operation &op : llvm::make_early_inc_range(*block)) { 364 builder.setInsertionPoint(&op); 365 366 // Replace any result with constants. 367 bool replacedAll = op.getNumResults() != 0; 368 for (Value res : op.getResults()) 369 replacedAll &= succeeded(replaceWithConstant(builder, folder, res)); 370 371 // If all of the results of the operation were replaced, try to erase 372 // the operation completely. 373 if (replacedAll && wouldOpBeTriviallyDead(&op)) { 374 assert(op.use_empty() && "expected all uses to be replaced"); 375 op.erase(); 376 continue; 377 } 378 379 // Add any the regions of this operation to the worklist. 380 addToWorklist(op.getRegions()); 381 } 382 } 383 } 384 385 void SCCPSolver::initializeSymbolCallables(Operation *op) { 386 // Initialize the set of symbol callables that can have their state tracked. 387 // This tracks which symbol callable operations we can propagate within and 388 // out of. 389 auto walkFn = [&](Operation *symTable, bool allUsesVisible) { 390 Region &symbolTableRegion = symTable->getRegion(0); 391 Block *symbolTableBlock = &symbolTableRegion.front(); 392 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { 393 // We won't be able to track external callables. 394 Region *callableRegion = callable.getCallableRegion(); 395 if (!callableRegion) 396 continue; 397 // We only care about symbol defining callables here. 398 auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation()); 399 if (!symbol) 400 continue; 401 callableLatticeState.try_emplace(callable, callableRegion, 402 callable.getCallableResults().size()); 403 404 // If not all of the uses of this symbol are visible, we can't track the 405 // state of the arguments. 406 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { 407 for (Region ®ion : callable->getRegions()) 408 markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); 409 } 410 } 411 if (callableLatticeState.empty()) 412 return; 413 414 // After computing the valid callables, walk any symbol uses to check 415 // for non-call references. We won't be able to track the lattice state 416 // for arguments to these callables, as we can't guarantee that we can see 417 // all of its calls. 418 Optional<SymbolTable::UseRange> uses = 419 SymbolTable::getSymbolUses(&symbolTableRegion); 420 if (!uses) { 421 // If we couldn't gather the symbol uses, conservatively assume that 422 // we can't track information for any nested symbols. 423 op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); }); 424 return; 425 } 426 427 for (const SymbolTable::SymbolUse &use : *uses) { 428 // If the use is a call, track it to avoid the need to recompute the 429 // reference later. 430 if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) { 431 Operation *symCallable = callOp.resolveCallable(&symbolTable); 432 auto callableLatticeIt = callableLatticeState.find(symCallable); 433 if (callableLatticeIt != callableLatticeState.end()) { 434 callToSymbolCallable.try_emplace(callOp, symCallable); 435 436 // We only need to record the call in the lattice if it produces any 437 // values. 438 if (callOp->getNumResults()) 439 callableLatticeIt->second.addSymbolCall(callOp); 440 } 441 continue; 442 } 443 // This use isn't a call, so don't we know all of the callers. 444 auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef()); 445 auto it = callableLatticeState.find(symbol); 446 if (it != callableLatticeState.end()) { 447 for (Region ®ion : it->first->getRegions()) 448 markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); 449 } 450 } 451 }; 452 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), 453 walkFn); 454 } 455 456 LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder, 457 OperationFolder &folder, 458 Value value) { 459 auto it = latticeValues.find(value); 460 auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant(); 461 if (!attr) 462 return failure(); 463 464 // Attempt to materialize a constant for the given value. 465 Dialect *dialect = it->second.getConstantDialect(); 466 Value constant = folder.getOrCreateConstant(builder, dialect, attr, 467 value.getType(), value.getLoc()); 468 if (!constant) 469 return failure(); 470 471 value.replaceAllUsesWith(constant); 472 latticeValues.erase(it); 473 return success(); 474 } 475 476 void SCCPSolver::visitOperation(Operation *op) { 477 // Collect all of the constant operands feeding into this operation. If any 478 // are not ready to be resolved, bail out and wait for them to resolve. 479 SmallVector<Attribute, 8> operandConstants; 480 operandConstants.reserve(op->getNumOperands()); 481 for (Value operand : op->getOperands()) { 482 // Make sure all of the operands are resolved first. 483 auto &operandLattice = latticeValues[operand]; 484 if (operandLattice.isUnknown()) 485 return; 486 operandConstants.push_back(operandLattice.getConstant()); 487 } 488 489 // If this is a terminator operation, process any control flow lattice state. 490 if (op->hasTrait<OpTrait::IsTerminator>()) 491 visitTerminatorOperation(op, operandConstants); 492 493 // Process call operations. The call visitor processes result values, so we 494 // can exit afterwards. 495 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) 496 return visitCallOperation(call); 497 498 // Process callable operations. These are specially handled region operations 499 // that track dataflow via calls. 500 if (isa<CallableOpInterface>(op)) { 501 // If this callable has a tracked lattice state, it will be visited by calls 502 // that reference it instead. This way, we don't assume that it is 503 // executable unless there is a proper reference to it. 504 if (callableLatticeState.count(op)) 505 return; 506 return visitCallableOperation(op); 507 } 508 509 // Process region holding operations. The region visitor processes result 510 // values, so we can exit afterwards. 511 if (op->getNumRegions()) 512 return visitRegionOperation(op, operandConstants); 513 514 // If this op produces no results, it can't produce any constants. 515 if (op->getNumResults() == 0) 516 return; 517 518 // If all of the results of this operation are already overdefined, bail out 519 // early. 520 auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); }; 521 if (llvm::all_of(op->getResults(), isOverdefinedFn)) 522 return; 523 524 // Save the original operands and attributes just in case the operation folds 525 // in-place. The constant passed in may not correspond to the real runtime 526 // value, so in-place updates are not allowed. 527 SmallVector<Value, 8> originalOperands(op->getOperands()); 528 DictionaryAttr originalAttrs = op->getAttrDictionary(); 529 530 // Simulate the result of folding this operation to a constant. If folding 531 // fails or was an in-place fold, mark the results as overdefined. 532 SmallVector<OpFoldResult, 8> foldResults; 533 foldResults.reserve(op->getNumResults()); 534 if (failed(op->fold(operandConstants, foldResults))) 535 return markAllOverdefined(op, op->getResults()); 536 537 // If the folding was in-place, mark the results as overdefined and reset the 538 // operation. We don't allow in-place folds as the desire here is for 539 // simulated execution, and not general folding. 540 if (foldResults.empty()) { 541 op->setOperands(originalOperands); 542 op->setAttrs(originalAttrs); 543 return markAllOverdefined(op, op->getResults()); 544 } 545 546 // Merge the fold results into the lattice for this operation. 547 assert(foldResults.size() == op->getNumResults() && "invalid result size"); 548 Dialect *opDialect = op->getDialect(); 549 for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { 550 LatticeValue &resultLattice = latticeValues[op->getResult(i)]; 551 552 // Merge in the result of the fold, either a constant or a value. 553 OpFoldResult foldResult = foldResults[i]; 554 if (Attribute foldAttr = foldResult.dyn_cast<Attribute>()) 555 meet(op, resultLattice, LatticeValue(foldAttr, opDialect)); 556 else 557 meet(op, resultLattice, latticeValues[foldResult.get<Value>()]); 558 } 559 } 560 561 void SCCPSolver::visitCallableOperation(Operation *op) { 562 // Mark the regions as executable. If we aren't tracking lattice state for 563 // this callable, mark all of the region arguments as overdefined. 564 bool isTrackingLatticeState = callableLatticeState.count(op); 565 for (Region ®ion : op->getRegions()) 566 markEntryBlockExecutable(®ion, !isTrackingLatticeState); 567 568 // TODO: Add support for non-symbol callables when necessary. If the callable 569 // has non-call uses we would mark overdefined, otherwise allow for 570 // propagating the return values out. 571 markAllOverdefined(op, op->getResults()); 572 } 573 574 void SCCPSolver::visitCallOperation(CallOpInterface op) { 575 ResultRange callResults = op->getResults(); 576 577 // Resolve the callable operation for this call. 578 Operation *callableOp = nullptr; 579 if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>()) 580 callableOp = callableValue.getDefiningOp(); 581 else 582 callableOp = callToSymbolCallable.lookup(op); 583 584 // The callable of this call can't be resolved, mark any results overdefined. 585 if (!callableOp) 586 return markAllOverdefined(op, callResults); 587 588 // If this callable is tracking state, merge the argument operands with the 589 // arguments of the callable. 590 auto callableLatticeIt = callableLatticeState.find(callableOp); 591 if (callableLatticeIt == callableLatticeState.end()) 592 return markAllOverdefined(op, callResults); 593 594 OperandRange callOperands = op.getArgOperands(); 595 auto callableArgs = callableLatticeIt->second.getCallableArguments(); 596 for (auto it : llvm::zip(callOperands, callableArgs)) { 597 BlockArgument callableArg = std::get<1>(it); 598 if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)])) 599 visitUsers(callableArg); 600 } 601 602 // Visit the callable. 603 visitCallableOperation(callableOp); 604 605 // Merge in the lattice state for the callable results as well. 606 auto callableResults = callableLatticeIt->second.getResultLatticeValues(); 607 for (auto it : llvm::zip(callResults, callableResults)) 608 meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)], 609 /*from=*/std::get<1>(it)); 610 } 611 612 void SCCPSolver::visitRegionOperation(Operation *op, 613 ArrayRef<Attribute> constantOperands) { 614 // Check to see if we can reason about the internal control flow of this 615 // region operation. 616 auto regionInterface = dyn_cast<RegionBranchOpInterface>(op); 617 if (!regionInterface) { 618 // If we can't, conservatively mark all regions as executable. 619 for (Region ®ion : op->getRegions()) 620 markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); 621 622 // Don't try to simulate the results of a region operation as we can't 623 // guarantee that folding will be out-of-place. We don't allow in-place 624 // folds as the desire here is for simulated execution, and not general 625 // folding. 626 return markAllOverdefined(op, op->getResults()); 627 } 628 629 // Check to see which regions are executable. 630 SmallVector<RegionSuccessor, 1> successors; 631 regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands, 632 successors); 633 634 // If the interface identified that no region will be executed. Mark 635 // any results of this operation as overdefined, as we can't reason about 636 // them. 637 // TODO: If we had an interface to detect pass through operands, we could 638 // resolve some results based on the lattice state of the operands. We could 639 // also allow for the parent operation to have itself as a region successor. 640 if (successors.empty()) 641 return markAllOverdefined(op, op->getResults()); 642 return visitRegionSuccessors(op, successors, [&](Optional<unsigned> index) { 643 assert(index && "expected valid region index"); 644 return regionInterface.getSuccessorEntryOperands(*index); 645 }); 646 } 647 648 void SCCPSolver::visitRegionSuccessors( 649 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, 650 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) { 651 for (const RegionSuccessor &it : regionSuccessors) { 652 Region *region = it.getSuccessor(); 653 ValueRange succArgs = it.getSuccessorInputs(); 654 655 // Check to see if this is the parent operation. 656 if (!region) { 657 ResultRange results = parentOp->getResults(); 658 if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); })) 659 continue; 660 661 // Mark the results outside of the input range as overdefined. 662 if (succArgs.size() != results.size()) { 663 opWorklist.push_back(parentOp); 664 if (succArgs.empty()) 665 return markAllOverdefined(results); 666 667 unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber(); 668 markAllOverdefined(results.take_front(firstResIdx)); 669 markAllOverdefined(results.drop_front(firstResIdx + succArgs.size())); 670 } 671 672 // Update the lattice for any operation results. 673 OperandRange operands = getInputsForRegion(/*index=*/llvm::None); 674 for (auto it : llvm::zip(succArgs, operands)) 675 meet(parentOp, latticeValues[std::get<0>(it)], 676 latticeValues[std::get<1>(it)]); 677 return; 678 } 679 assert(!region->empty() && "expected region to be non-empty"); 680 Block *entryBlock = ®ion->front(); 681 markBlockExecutable(entryBlock); 682 683 // If all of the arguments are already overdefined, the arguments have 684 // already been fully resolved. 685 auto arguments = entryBlock->getArguments(); 686 if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); })) 687 continue; 688 689 // Mark any arguments that do not receive inputs as overdefined, we won't be 690 // able to discern if they are constant. 691 if (succArgs.size() != arguments.size()) { 692 if (succArgs.empty()) { 693 markAllOverdefined(arguments); 694 continue; 695 } 696 697 unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber(); 698 markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx)); 699 markAllOverdefinedAndVisitUsers( 700 arguments.drop_front(firstArgIdx + succArgs.size())); 701 } 702 703 // Update the lattice for arguments that have inputs from the predecessor. 704 OperandRange succOperands = getInputsForRegion(region->getRegionNumber()); 705 for (auto it : llvm::zip(succArgs, succOperands)) { 706 LatticeValue &argLattice = latticeValues[std::get<0>(it)]; 707 if (argLattice.meet(latticeValues[std::get<1>(it)])) 708 visitUsers(std::get<0>(it)); 709 } 710 } 711 } 712 713 void SCCPSolver::visitTerminatorOperation( 714 Operation *op, ArrayRef<Attribute> constantOperands) { 715 // If this operation has no successors, we treat it as an exiting terminator. 716 if (op->getNumSuccessors() == 0) { 717 Region *parentRegion = op->getParentRegion(); 718 Operation *parentOp = parentRegion->getParentOp(); 719 720 // Check to see if this is a terminator for a callable region. 721 if (isa<CallableOpInterface>(parentOp)) 722 return visitCallableTerminatorOperation(parentOp, op); 723 724 // Otherwise, check to see if the parent tracks region control flow. 725 auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp); 726 if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) 727 return; 728 729 // Query the set of successors from the current region. 730 SmallVector<RegionSuccessor, 1> regionSuccessors; 731 regionInterface.getSuccessorRegions(parentRegion->getRegionNumber(), 732 constantOperands, regionSuccessors); 733 if (regionSuccessors.empty()) 734 return; 735 736 // If this terminator is not "region-like", conservatively mark all of the 737 // successor values as overdefined. 738 if (!op->hasTrait<OpTrait::ReturnLike>()) { 739 for (auto &it : regionSuccessors) 740 markAllOverdefinedAndVisitUsers(it.getSuccessorInputs()); 741 return; 742 } 743 744 // Otherwise, propagate the operand lattice states to each of the 745 // successors. 746 OperandRange operands = op->getOperands(); 747 return visitRegionSuccessors(parentOp, regionSuccessors, 748 [&](Optional<unsigned>) { return operands; }); 749 } 750 751 // Try to resolve to a specific successor with the constant operands. 752 if (auto branch = dyn_cast<BranchOpInterface>(op)) { 753 if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { 754 markEdgeExecutable(op->getBlock(), singleSucc); 755 return; 756 } 757 } 758 759 // Otherwise, conservatively treat all edges as executable. 760 Block *block = op->getBlock(); 761 for (Block *succ : op->getSuccessors()) 762 markEdgeExecutable(block, succ); 763 } 764 765 void SCCPSolver::visitCallableTerminatorOperation(Operation *callable, 766 Operation *terminator) { 767 // If there are no exiting values, we have nothing to track. 768 if (terminator->getNumOperands() == 0) 769 return; 770 771 // If this callable isn't tracking any lattice state there is nothing to do. 772 auto latticeIt = callableLatticeState.find(callable); 773 if (latticeIt == callableLatticeState.end()) 774 return; 775 assert(callable->getNumResults() == 0 && "expected symbol callable"); 776 777 // If this terminator is not "return-like", conservatively mark all of the 778 // call-site results as overdefined. 779 auto callableResultLattices = latticeIt->second.getResultLatticeValues(); 780 if (!terminator->hasTrait<OpTrait::ReturnLike>()) { 781 for (auto &it : callableResultLattices) 782 it.markOverdefined(); 783 for (Operation *call : latticeIt->second.getSymbolCalls()) 784 markAllOverdefined(call, call->getResults()); 785 return; 786 } 787 788 // Merge the terminator operands into the results. 789 bool anyChanged = false; 790 for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices)) 791 anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]); 792 if (!anyChanged) 793 return; 794 795 // If any of the result lattices changed, update the callers. 796 for (Operation *call : latticeIt->second.getSymbolCalls()) 797 for (auto it : llvm::zip(call->getResults(), callableResultLattices)) 798 meet(call, latticeValues[std::get<0>(it)], std::get<1>(it)); 799 } 800 801 void SCCPSolver::visitBlock(Block *block) { 802 // If the block is not the entry block we need to compute the lattice state 803 // for the block arguments. Entry block argument lattices are computed 804 // elsewhere, such as when visiting the parent operation. 805 if (!block->isEntryBlock()) { 806 for (int i : llvm::seq<int>(0, block->getNumArguments())) 807 visitBlockArgument(block, i); 808 } 809 810 // Visit all of the operations within the block. 811 for (Operation &op : *block) 812 visitOperation(&op); 813 } 814 815 void SCCPSolver::visitBlockArgument(Block *block, int i) { 816 BlockArgument arg = block->getArgument(i); 817 LatticeValue &argLattice = latticeValues[arg]; 818 if (argLattice.isOverdefined()) 819 return; 820 821 bool updatedLattice = false; 822 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 823 Block *pred = *it; 824 825 // We only care about this predecessor if it is going to execute. 826 if (!isEdgeExecutable(pred, block)) 827 continue; 828 829 // Try to get the operand forwarded by the predecessor. If we can't reason 830 // about the terminator of the predecessor, mark overdefined. 831 Optional<OperandRange> branchOperands; 832 if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator())) 833 branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex()); 834 if (!branchOperands) { 835 updatedLattice = true; 836 argLattice.markOverdefined(); 837 break; 838 } 839 840 // If the operand hasn't been resolved, it is unknown which can merge with 841 // anything. 842 auto operandLattice = latticeValues.find((*branchOperands)[i]); 843 if (operandLattice == latticeValues.end()) 844 continue; 845 846 // Otherwise, meet the two lattice values. 847 updatedLattice |= argLattice.meet(operandLattice->second); 848 if (argLattice.isOverdefined()) 849 break; 850 } 851 852 // If the lattice was updated, visit any executable users of the argument. 853 if (updatedLattice) 854 visitUsers(arg); 855 } 856 857 bool SCCPSolver::markEntryBlockExecutable(Region *region, 858 bool markArgsOverdefined) { 859 if (!region->empty()) { 860 if (markArgsOverdefined) 861 markAllOverdefined(region->front().getArguments()); 862 return markBlockExecutable(®ion->front()); 863 } 864 return false; 865 } 866 867 bool SCCPSolver::markBlockExecutable(Block *block) { 868 bool marked = executableBlocks.insert(block).second; 869 if (marked) 870 blockWorklist.push_back(block); 871 return marked; 872 } 873 874 bool SCCPSolver::isBlockExecutable(Block *block) const { 875 return executableBlocks.count(block); 876 } 877 878 void SCCPSolver::markEdgeExecutable(Block *from, Block *to) { 879 if (!executableEdges.insert(std::make_pair(from, to)).second) 880 return; 881 // Mark the destination as executable, and reprocess its arguments if it was 882 // already executable. 883 if (!markBlockExecutable(to)) { 884 for (int i : llvm::seq<int>(0, to->getNumArguments())) 885 visitBlockArgument(to, i); 886 } 887 } 888 889 bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const { 890 return executableEdges.count(std::make_pair(from, to)); 891 } 892 893 void SCCPSolver::markOverdefined(Value value) { 894 latticeValues[value].markOverdefined(); 895 } 896 897 bool SCCPSolver::isOverdefined(Value value) const { 898 auto it = latticeValues.find(value); 899 return it != latticeValues.end() && it->second.isOverdefined(); 900 } 901 902 void SCCPSolver::meet(Operation *owner, LatticeValue &to, 903 const LatticeValue &from) { 904 if (to.meet(from)) 905 opWorklist.push_back(owner); 906 } 907 908 //===----------------------------------------------------------------------===// 909 // SCCP Pass 910 //===----------------------------------------------------------------------===// 911 912 namespace { 913 struct SCCP : public SCCPBase<SCCP> { 914 void runOnOperation() override; 915 }; 916 } // end anonymous namespace 917 918 void SCCP::runOnOperation() { 919 Operation *op = getOperation(); 920 921 // Solve for SCCP constraints within nested regions. 922 SCCPSolver solver(op); 923 solver.solve(); 924 925 // Cleanup any operations using the solver analysis. 926 solver.rewrite(&getContext(), op->getRegions()); 927 } 928 929 std::unique_ptr<Pass> mlir::createSCCPPass() { 930 return std::make_unique<SCCP>(); 931 } 932