1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for buffer optimization passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Interfaces/ControlFlowInterfaces.h" 19 #include "mlir/Interfaces/LoopLikeInterface.h" 20 #include "mlir/Pass/Pass.h" 21 #include "llvm/ADT/SetOperations.h" 22 23 using namespace mlir; 24 using namespace mlir::bufferization; 25 26 //===----------------------------------------------------------------------===// 27 // BufferPlacementAllocs 28 //===----------------------------------------------------------------------===// 29 30 /// Get the start operation to place the given alloc value withing the 31 // specified placement block. 32 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, 33 Block *placementBlock, 34 const Liveness &liveness) { 35 // We have to ensure that we place the alloc before its first use in this 36 // block. 37 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock); 38 Operation *startOperation = livenessInfo.getStartOperation(allocValue); 39 // Check whether the start operation lies in the desired placement block. 40 // If not, we will use the terminator as this is the last operation in 41 // this block. 42 if (startOperation->getBlock() != placementBlock) { 43 Operation *opInPlacementBlock = 44 placementBlock->findAncestorOpInBlock(*startOperation); 45 startOperation = opInPlacementBlock ? opInPlacementBlock 46 : placementBlock->getTerminator(); 47 } 48 49 return startOperation; 50 } 51 52 /// Initializes the internal list by discovering all supported allocation 53 /// nodes. 54 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } 55 56 /// Searches for and registers all supported allocation entries. 57 void BufferPlacementAllocs::build(Operation *op) { 58 op->walk([&](MemoryEffectOpInterface opInterface) { 59 // Try to find a single allocation result. 60 SmallVector<MemoryEffects::EffectInstance, 2> effects; 61 opInterface.getEffects(effects); 62 63 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; 64 llvm::copy_if( 65 effects, std::back_inserter(allocateResultEffects), 66 [=](MemoryEffects::EffectInstance &it) { 67 Value value = it.getValue(); 68 return isa<MemoryEffects::Allocate>(it.getEffect()) && value && 69 value.isa<OpResult>() && 70 it.getResource() != 71 SideEffects::AutomaticAllocationScopeResource::get(); 72 }); 73 // If there is one result only, we will be able to move the allocation and 74 // (possibly existing) deallocation ops. 75 if (allocateResultEffects.size() != 1) 76 return; 77 // Get allocation result. 78 Value allocValue = allocateResultEffects[0].getValue(); 79 // Find the associated dealloc value and register the allocation entry. 80 llvm::Optional<Operation *> dealloc = findDealloc(allocValue); 81 // If the allocation has > 1 dealloc associated with it, skip handling it. 82 if (!dealloc.hasValue()) 83 return; 84 allocs.push_back(std::make_tuple(allocValue, *dealloc)); 85 }); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // BufferPlacementTransformationBase 90 //===----------------------------------------------------------------------===// 91 92 /// Constructs a new transformation base using the given root operation. 93 BufferPlacementTransformationBase::BufferPlacementTransformationBase( 94 Operation *op) 95 : aliases(op), allocs(op), liveness(op) {} 96 97 /// Returns true if the given operation represents a loop by testing whether it 98 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In 99 /// the case of a `RegionBranchOpInterface`, it checks all region-based control- 100 /// flow edges for cycles. 101 bool BufferPlacementTransformationBase::isLoop(Operation *op) { 102 // If the operation implements the `LoopLikeOpInterface` it can be considered 103 // a loop. 104 if (isa<LoopLikeOpInterface>(op)) 105 return true; 106 107 // If the operation does not implement the `RegionBranchOpInterface`, it is 108 // (currently) not possible to detect a loop. 109 RegionBranchOpInterface regionInterface; 110 if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op))) 111 return false; 112 113 // Recurses into a region using the current region interface to find potential 114 // cycles. 115 SmallPtrSet<Region *, 4> visitedRegions; 116 std::function<bool(Region *)> recurse = [&](Region *current) { 117 if (!current) 118 return false; 119 // If we have found a back edge, the parent operation induces a loop. 120 if (!visitedRegions.insert(current).second) 121 return true; 122 // Recurses into all region successors. 123 SmallVector<RegionSuccessor, 2> successors; 124 regionInterface.getSuccessorRegions(current->getRegionNumber(), successors); 125 for (RegionSuccessor ®ionEntry : successors) 126 if (recurse(regionEntry.getSuccessor())) 127 return true; 128 return false; 129 }; 130 131 // Start with all entry regions and test whether they induce a loop. 132 SmallVector<RegionSuccessor, 2> successorRegions; 133 regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions); 134 for (RegionSuccessor ®ionEntry : successorRegions) { 135 if (recurse(regionEntry.getSuccessor())) 136 return true; 137 visitedRegions.clear(); 138 } 139 140 return false; 141 } 142 143 //===----------------------------------------------------------------------===// 144 // BufferPlacementTransformationBase 145 //===----------------------------------------------------------------------===// 146 147 FailureOr<memref::GlobalOp> 148 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) { 149 auto type = constantOp.getType().cast<RankedTensorType>(); 150 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 151 if (!moduleOp) 152 return failure(); 153 154 // If we already have a global for this constant value, no need to do 155 // anything else. 156 for (Operation &op : moduleOp.getRegion().getOps()) { 157 auto globalOp = dyn_cast<memref::GlobalOp>(&op); 158 if (!globalOp) 159 continue; 160 if (!globalOp.initial_value().hasValue()) 161 continue; 162 uint64_t opAlignment = 163 globalOp.alignment().hasValue() ? globalOp.alignment().getValue() : 0; 164 Attribute initialValue = globalOp.initial_value().getValue(); 165 if (opAlignment == alignment && initialValue == constantOp.getValue()) 166 return globalOp; 167 } 168 169 // Create a builder without an insertion point. We will insert using the 170 // symbol table to guarantee unique names. 171 OpBuilder globalBuilder(moduleOp.getContext()); 172 SymbolTable symbolTable(moduleOp); 173 174 // Create a pretty name. 175 SmallString<64> buf; 176 llvm::raw_svector_ostream os(buf); 177 interleave(type.getShape(), os, "x"); 178 os << "x" << type.getElementType(); 179 180 // Add an optional alignment to the global memref. 181 IntegerAttr memrefAlignment = 182 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) 183 : IntegerAttr(); 184 185 BufferizeTypeConverter typeConverter; 186 auto global = globalBuilder.create<memref::GlobalOp>( 187 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), 188 /*sym_visibility=*/globalBuilder.getStringAttr("private"), 189 /*type=*/typeConverter.convertType(type).cast<MemRefType>(), 190 /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), 191 /*constant=*/true, 192 /*alignment=*/memrefAlignment); 193 symbolTable.insert(global); 194 // The symbol table inserts at the end of the module, but globals are a bit 195 // nicer if they are at the beginning. 196 global->moveBefore(&moduleOp.front()); 197 return global; 198 } 199