1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass performs a sparse conditional constant propagation 10 // in MLIR. It identifies values known to be constant, propagates that 11 // information throughout the IR, and replaces them. This is done with an 12 // optimisitic dataflow analysis that assumes that all values are constant until 13 // proven otherwise. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/Interfaces/ControlFlowInterfaces.h" 21 #include "mlir/Interfaces/SideEffects.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Transforms/FoldUtils.h" 24 #include "mlir/Transforms/Passes.h" 25 26 using namespace mlir; 27 28 namespace { 29 /// This class represents a single lattice value. A lattive value corresponds to 30 /// the various different states that a value in the SCCP dataflow anaylsis can 31 /// take. See 'Kind' below for more details on the different states a value can 32 /// take. 33 class LatticeValue { 34 enum Kind { 35 /// A value with a yet to be determined value. This state may be changed to 36 /// anything. 37 Unknown, 38 39 /// A value that is known to be a constant. This state may be changed to 40 /// overdefined. 41 Constant, 42 43 /// A value that cannot statically be determined to be a constant. This 44 /// state cannot be changed. 45 Overdefined 46 }; 47 48 public: 49 /// Initialize a lattice value with "Unknown". 50 LatticeValue() 51 : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {} 52 /// Initialize a lattice value with a constant. 53 LatticeValue(Attribute attr, Dialect *dialect) 54 : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {} 55 56 /// Returns true if this lattice value is unknown. 57 bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; } 58 59 /// Mark the lattice value as overdefined. 60 void markOverdefined() { 61 constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined); 62 constantDialect = nullptr; 63 } 64 65 /// Returns true if the lattice is overdefined. 66 bool isOverdefined() const { 67 return constantAndTag.getInt() == Kind::Overdefined; 68 } 69 70 /// Mark the lattice value as constant. 71 void markConstant(Attribute value, Dialect *dialect) { 72 constantAndTag.setPointerAndInt(value, Kind::Constant); 73 constantDialect = dialect; 74 } 75 76 /// If this lattice is constant, return the constant. Returns nullptr 77 /// otherwise. 78 Attribute getConstant() const { return constantAndTag.getPointer(); } 79 80 /// If this lattice is constant, return the dialect to use when materializing 81 /// the constant. 82 Dialect *getConstantDialect() const { 83 assert(getConstant() && "expected valid constant"); 84 return constantDialect; 85 } 86 87 /// Merge in the value of the 'rhs' lattice into this one. Returns true if the 88 /// lattice value changed. 89 bool meet(const LatticeValue &rhs) { 90 // If we are already overdefined, or rhs is unknown, there is nothing to do. 91 if (isOverdefined() || rhs.isUnknown()) 92 return false; 93 // If we are unknown, just take the value of rhs. 94 if (isUnknown()) { 95 constantAndTag = rhs.constantAndTag; 96 constantDialect = rhs.constantDialect; 97 return true; 98 } 99 100 // Otherwise, if this value doesn't match rhs go straight to overdefined. 101 if (constantAndTag != rhs.constantAndTag) { 102 markOverdefined(); 103 return true; 104 } 105 return false; 106 } 107 108 private: 109 /// The attribute value if this is a constant and the tag for the element 110 /// kind. 111 llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag; 112 113 /// The dialect the constant originated from. This is only valid if the 114 /// lattice is a constant. This is not used as part of the key, and is only 115 /// needed to materialize the held constant if necessary. 116 Dialect *constantDialect; 117 }; 118 119 /// This class represents the solver for the SCCP analysis. This class acts as 120 /// the propagation engine for computing which values form constants. 121 class SCCPSolver { 122 public: 123 /// Initialize the solver with a given set of regions. 124 SCCPSolver(MutableArrayRef<Region> regions); 125 126 /// Run the solver until it converges. 127 void solve(); 128 129 /// Rewrite the given regions using the computing analysis. This replaces the 130 /// uses of all values that have been computed to be constant, and erases as 131 /// many newly dead operations. 132 void rewrite(MLIRContext *context, MutableArrayRef<Region> regions); 133 134 private: 135 /// Replace the given value with a constant if the corresponding lattice 136 /// represents a constant. Returns success if the value was replaced, failure 137 /// otherwise. 138 LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder, 139 Value value); 140 141 /// Visit the given operation and compute any necessary lattice state. 142 void visitOperation(Operation *op); 143 144 /// Visit the given operation, which defines regions, and compute any 145 /// necessary lattice state. This also resolves the lattice state of both the 146 /// operation results and any nested regions. 147 void visitRegionOperation(Operation *op); 148 149 /// Visit the given terminator operation and compute any necessary lattice 150 /// state. 151 void visitTerminatorOperation(Operation *op, 152 ArrayRef<Attribute> constantOperands); 153 154 /// Visit the given block and compute any necessary lattice state. 155 void visitBlock(Block *block); 156 157 /// Visit argument #'i' of the given block and compute any necessary lattice 158 /// state. 159 void visitBlockArgument(Block *block, int i); 160 161 /// Mark the given block as executable. Returns false if the block was already 162 /// marked executable. 163 bool markBlockExecutable(Block *block); 164 165 /// Returns true if the given block is executable. 166 bool isBlockExecutable(Block *block) const; 167 168 /// Mark the edge between 'from' and 'to' as executable. 169 void markEdgeExecutable(Block *from, Block *to); 170 171 /// Return true if the edge between 'from' and 'to' is executable. 172 bool isEdgeExecutable(Block *from, Block *to) const; 173 174 /// Mark the given value as overdefined. This means that we cannot refine a 175 /// specific constant for this value. 176 void markOverdefined(Value value); 177 178 /// Mark all of the given values as overdefined. 179 template <typename ValuesT> 180 void markAllOverdefined(ValuesT values) { 181 for (auto value : values) 182 markOverdefined(value); 183 } 184 template <typename ValuesT> 185 void markAllOverdefined(Operation *op, ValuesT values) { 186 markAllOverdefined(values); 187 opWorklist.push_back(op); 188 } 189 190 /// Returns true if the given value was marked as overdefined. 191 bool isOverdefined(Value value) const; 192 193 /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' 194 /// corresponds to the parent operation of 'to'. 195 void meet(Operation *owner, LatticeValue &to, const LatticeValue &from); 196 197 /// The lattice for each SSA value. 198 DenseMap<Value, LatticeValue> latticeValues; 199 200 /// The set of blocks that are known to execute, or are intrinsically live. 201 SmallPtrSet<Block *, 16> executableBlocks; 202 203 /// The set of control flow edges that are known to execute. 204 DenseSet<std::pair<Block *, Block *>> executableEdges; 205 206 /// A worklist containing blocks that need to be processed. 207 SmallVector<Block *, 64> blockWorklist; 208 209 /// A worklist of operations that need to be processed. 210 SmallVector<Operation *, 64> opWorklist; 211 }; 212 } // end anonymous namespace 213 214 SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) { 215 for (Region ®ion : regions) { 216 if (region.empty()) 217 continue; 218 Block *entryBlock = ®ion.front(); 219 220 // Mark the entry block as executable. 221 markBlockExecutable(entryBlock); 222 223 // The values passed to these regions are invisible, so mark any arguments 224 // as overdefined. 225 markAllOverdefined(entryBlock->getArguments()); 226 } 227 } 228 229 void SCCPSolver::solve() { 230 while (!blockWorklist.empty() || !opWorklist.empty()) { 231 // Process any operations in the op worklist. 232 while (!opWorklist.empty()) { 233 Operation *op = opWorklist.pop_back_val(); 234 235 // Visit all of the live users to propagate changes to this operation. 236 for (Operation *user : op->getUsers()) { 237 if (isBlockExecutable(user->getBlock())) 238 visitOperation(user); 239 } 240 } 241 242 // Process any blocks in the block worklist. 243 while (!blockWorklist.empty()) 244 visitBlock(blockWorklist.pop_back_val()); 245 } 246 } 247 248 void SCCPSolver::rewrite(MLIRContext *context, 249 MutableArrayRef<Region> initialRegions) { 250 SmallVector<Block *, 8> worklist; 251 auto addToWorklist = [&](MutableArrayRef<Region> regions) { 252 for (Region ®ion : regions) 253 for (Block &block : region) 254 if (isBlockExecutable(&block)) 255 worklist.push_back(&block); 256 }; 257 258 // An operation folder used to create and unique constants. 259 OperationFolder folder(context); 260 OpBuilder builder(context); 261 262 addToWorklist(initialRegions); 263 while (!worklist.empty()) { 264 Block *block = worklist.pop_back_val(); 265 266 // Replace any block arguments with constants. 267 builder.setInsertionPointToStart(block); 268 for (BlockArgument arg : block->getArguments()) 269 replaceWithConstant(builder, folder, arg); 270 271 for (Operation &op : llvm::make_early_inc_range(*block)) { 272 builder.setInsertionPoint(&op); 273 274 // Replace any result with constants. 275 bool replacedAll = op.getNumResults() != 0; 276 for (Value res : op.getResults()) 277 replacedAll &= succeeded(replaceWithConstant(builder, folder, res)); 278 279 // If all of the results of the operation were replaced, try to erase 280 // the operation completely. 281 if (replacedAll && wouldOpBeTriviallyDead(&op)) { 282 assert(op.use_empty() && "expected all uses to be replaced"); 283 op.erase(); 284 continue; 285 } 286 287 // Add any the regions of this operation to the worklist. 288 addToWorklist(op.getRegions()); 289 } 290 } 291 } 292 293 LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder, 294 OperationFolder &folder, 295 Value value) { 296 auto it = latticeValues.find(value); 297 auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant(); 298 if (!attr) 299 return failure(); 300 301 // Attempt to materialize a constant for the given value. 302 Dialect *dialect = it->second.getConstantDialect(); 303 Value constant = folder.getOrCreateConstant(builder, dialect, attr, 304 value.getType(), value.getLoc()); 305 if (!constant) 306 return failure(); 307 308 value.replaceAllUsesWith(constant); 309 latticeValues.erase(it); 310 return success(); 311 } 312 313 void SCCPSolver::visitOperation(Operation *op) { 314 // Collect all of the constant operands feeding into this operation. If any 315 // are not ready to be resolved, bail out and wait for them to resolve. 316 SmallVector<Attribute, 8> operandConstants; 317 operandConstants.reserve(op->getNumOperands()); 318 for (Value operand : op->getOperands()) { 319 // Make sure all of the operands are resolved first. 320 auto &operandLattice = latticeValues[operand]; 321 if (operandLattice.isUnknown()) 322 return; 323 operandConstants.push_back(operandLattice.getConstant()); 324 } 325 326 // If this is a terminator operation, process any control flow lattice state. 327 if (op->isKnownTerminator()) 328 visitTerminatorOperation(op, operandConstants); 329 330 // Process region holding operations. The region visitor processes result 331 // values, so we can exit afterwards. 332 if (op->getNumRegions()) 333 return visitRegionOperation(op); 334 335 // If this op produces no results, it can't produce any constants. 336 if (op->getNumResults() == 0) 337 return; 338 339 // If all of the results of this operation are already overdefined, bail out 340 // early. 341 auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); }; 342 if (llvm::all_of(op->getResults(), isOverdefinedFn)) 343 return; 344 345 // Save the original operands and attributes just in case the operation folds 346 // in-place. The constant passed in may not correspond to the real runtime 347 // value, so in-place updates are not allowed. 348 SmallVector<Value, 8> originalOperands(op->getOperands()); 349 NamedAttributeList originalAttrs = op->getAttrList(); 350 351 // Simulate the result of folding this operation to a constant. If folding 352 // fails or was an in-place fold, mark the results as overdefined. 353 SmallVector<OpFoldResult, 8> foldResults; 354 foldResults.reserve(op->getNumResults()); 355 if (failed(op->fold(operandConstants, foldResults))) 356 return markAllOverdefined(op, op->getResults()); 357 358 // If the folding was in-place, mark the results as overdefined and reset the 359 // operation. We don't allow in-place folds as the desire here is for 360 // simulated execution, and not general folding. 361 if (foldResults.empty()) { 362 op->setOperands(originalOperands); 363 op->setAttrs(originalAttrs); 364 return markAllOverdefined(op, op->getResults()); 365 } 366 367 // Merge the fold results into the lattice for this operation. 368 assert(foldResults.size() == op->getNumResults() && "invalid result size"); 369 Dialect *opDialect = op->getDialect(); 370 for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { 371 LatticeValue &resultLattice = latticeValues[op->getResult(i)]; 372 373 // Merge in the result of the fold, either a constant or a value. 374 OpFoldResult foldResult = foldResults[i]; 375 if (Attribute foldAttr = foldResult.dyn_cast<Attribute>()) 376 meet(op, resultLattice, LatticeValue(foldAttr, opDialect)); 377 else 378 meet(op, resultLattice, latticeValues[foldResult.get<Value>()]); 379 } 380 } 381 382 void SCCPSolver::visitRegionOperation(Operation *op) { 383 for (Region ®ion : op->getRegions()) { 384 if (region.empty()) 385 continue; 386 Block *entryBlock = ®ion.front(); 387 markBlockExecutable(entryBlock); 388 markAllOverdefined(entryBlock->getArguments()); 389 } 390 391 // Don't try to simulate the results of a region operation as we can't 392 // guarantee that folding will be out-of-place. We don't allow in-place folds 393 // as the desire here is for simulated execution, and not general folding. 394 return markAllOverdefined(op, op->getResults()); 395 } 396 397 void SCCPSolver::visitTerminatorOperation( 398 Operation *op, ArrayRef<Attribute> constantOperands) { 399 if (op->getNumSuccessors() == 0) 400 return; 401 402 // Try to resolve to a specific successor with the constant operands. 403 if (auto branch = dyn_cast<BranchOpInterface>(op)) { 404 if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { 405 markEdgeExecutable(op->getBlock(), singleSucc); 406 return; 407 } 408 } 409 410 // Otherwise, conservatively treat all edges as executable. 411 Block *block = op->getBlock(); 412 for (Block *succ : op->getSuccessors()) 413 markEdgeExecutable(block, succ); 414 } 415 416 void SCCPSolver::visitBlock(Block *block) { 417 // If the block is not the entry block we need to compute the lattice state 418 // for the block arguments. Entry block argument lattices are computed 419 // elsewhere, such as when visiting the parent operation. 420 if (!block->isEntryBlock()) { 421 for (int i : llvm::seq<int>(0, block->getNumArguments())) 422 visitBlockArgument(block, i); 423 } 424 425 // Visit all of the operations within the block. 426 for (Operation &op : *block) 427 visitOperation(&op); 428 } 429 430 void SCCPSolver::visitBlockArgument(Block *block, int i) { 431 BlockArgument arg = block->getArgument(i); 432 LatticeValue &argLattice = latticeValues[arg]; 433 if (argLattice.isOverdefined()) 434 return; 435 436 bool updatedLattice = false; 437 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 438 Block *pred = *it; 439 440 // We only care about this predecessor if it is going to execute. 441 if (!isEdgeExecutable(pred, block)) 442 continue; 443 444 // Try to get the operand forwarded by the predecessor. If we can't reason 445 // about the terminator of the predecessor, mark overdefined. 446 Optional<OperandRange> branchOperands; 447 if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator())) 448 branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex()); 449 if (!branchOperands) { 450 updatedLattice = true; 451 argLattice.markOverdefined(); 452 break; 453 } 454 455 // If the operand hasn't been resolved, it is unknown which can merge with 456 // anything. 457 auto operandLattice = latticeValues.find((*branchOperands)[i]); 458 if (operandLattice == latticeValues.end()) 459 continue; 460 461 // Otherwise, meet the two lattice values. 462 updatedLattice |= argLattice.meet(operandLattice->second); 463 if (argLattice.isOverdefined()) 464 break; 465 } 466 467 // If the lattice was updated, visit any executable users of the argument. 468 if (updatedLattice) { 469 for (Operation *user : arg.getUsers()) 470 if (isBlockExecutable(user->getBlock())) 471 visitOperation(user); 472 } 473 } 474 475 bool SCCPSolver::markBlockExecutable(Block *block) { 476 bool marked = executableBlocks.insert(block).second; 477 if (marked) 478 blockWorklist.push_back(block); 479 return marked; 480 } 481 482 bool SCCPSolver::isBlockExecutable(Block *block) const { 483 return executableBlocks.count(block); 484 } 485 486 void SCCPSolver::markEdgeExecutable(Block *from, Block *to) { 487 if (!executableEdges.insert(std::make_pair(from, to)).second) 488 return; 489 // Mark the destination as executable, and reprocess its arguments if it was 490 // already executable. 491 if (!markBlockExecutable(to)) { 492 for (int i : llvm::seq<int>(0, to->getNumArguments())) 493 visitBlockArgument(to, i); 494 } 495 } 496 497 bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const { 498 return executableEdges.count(std::make_pair(from, to)); 499 } 500 501 void SCCPSolver::markOverdefined(Value value) { 502 latticeValues[value].markOverdefined(); 503 } 504 505 bool SCCPSolver::isOverdefined(Value value) const { 506 auto it = latticeValues.find(value); 507 return it != latticeValues.end() && it->second.isOverdefined(); 508 } 509 510 void SCCPSolver::meet(Operation *owner, LatticeValue &to, 511 const LatticeValue &from) { 512 if (to.meet(from)) 513 opWorklist.push_back(owner); 514 } 515 516 //===----------------------------------------------------------------------===// 517 // SCCP Pass 518 //===----------------------------------------------------------------------===// 519 520 namespace { 521 struct SCCP : public SCCPBase<SCCP> { 522 void runOnOperation() override; 523 }; 524 } // end anonymous namespace 525 526 void SCCP::runOnOperation() { 527 Operation *op = getOperation(); 528 529 // Solve for SCCP constraints within nested regions. 530 SCCPSolver solver(op->getRegions()); 531 solver.solve(); 532 533 // Cleanup any operations using the solver analysis. 534 solver.rewrite(&getContext(), op->getRegions()); 535 } 536 537 std::unique_ptr<Pass> mlir::createSCCPPass() { 538 return std::make_unique<SCCP>(); 539 } 540