1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for buffer optimization passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Interfaces/ControlFlowInterfaces.h" 19 #include "mlir/Interfaces/LoopLikeInterface.h" 20 #include "mlir/Pass/Pass.h" 21 #include "llvm/ADT/SetOperations.h" 22 23 using namespace mlir; 24 using namespace mlir::bufferization; 25 26 //===----------------------------------------------------------------------===// 27 // BufferPlacementAllocs 28 //===----------------------------------------------------------------------===// 29 30 /// Get the start operation to place the given alloc value withing the 31 // specified placement block. 32 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue, 33 Block *placementBlock, 34 const Liveness &liveness) { 35 // We have to ensure that we place the alloc before its first use in this 36 // block. 37 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock); 38 Operation *startOperation = livenessInfo.getStartOperation(allocValue); 39 // Check whether the start operation lies in the desired placement block. 40 // If not, we will use the terminator as this is the last operation in 41 // this block. 42 if (startOperation->getBlock() != placementBlock) { 43 Operation *opInPlacementBlock = 44 placementBlock->findAncestorOpInBlock(*startOperation); 45 startOperation = opInPlacementBlock ? opInPlacementBlock 46 : placementBlock->getTerminator(); 47 } 48 49 return startOperation; 50 } 51 52 /// Initializes the internal list by discovering all supported allocation 53 /// nodes. 54 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } 55 56 /// Searches for and registers all supported allocation entries. 57 void BufferPlacementAllocs::build(Operation *op) { 58 op->walk([&](MemoryEffectOpInterface opInterface) { 59 // Try to find a single allocation result. 60 SmallVector<MemoryEffects::EffectInstance, 2> effects; 61 opInterface.getEffects(effects); 62 63 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects; 64 llvm::copy_if( 65 effects, std::back_inserter(allocateResultEffects), 66 [=](MemoryEffects::EffectInstance &it) { 67 Value value = it.getValue(); 68 return isa<MemoryEffects::Allocate>(it.getEffect()) && value && 69 value.isa<OpResult>() && 70 it.getResource() != 71 SideEffects::AutomaticAllocationScopeResource::get(); 72 }); 73 // If there is one result only, we will be able to move the allocation and 74 // (possibly existing) deallocation ops. 75 if (allocateResultEffects.size() != 1) 76 return; 77 // Get allocation result. 78 Value allocValue = allocateResultEffects[0].getValue(); 79 // Find the associated dealloc value and register the allocation entry. 80 llvm::Optional<Operation *> dealloc = findDealloc(allocValue); 81 // If the allocation has > 1 dealloc associated with it, skip handling it. 82 if (!dealloc.hasValue()) 83 return; 84 allocs.push_back(std::make_tuple(allocValue, *dealloc)); 85 }); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // BufferPlacementTransformationBase 90 //===----------------------------------------------------------------------===// 91 92 /// Constructs a new transformation base using the given root operation. 93 BufferPlacementTransformationBase::BufferPlacementTransformationBase( 94 Operation *op) 95 : aliases(op), allocs(op), liveness(op) {} 96 97 /// Returns true if the given operation represents a loop by testing whether it 98 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In 99 /// the case of a `RegionBranchOpInterface`, it checks all region-based control- 100 /// flow edges for cycles. 101 bool BufferPlacementTransformationBase::isLoop(Operation *op) { 102 // If the operation implements the `LoopLikeOpInterface` it can be considered 103 // a loop. 104 if (isa<LoopLikeOpInterface>(op)) 105 return true; 106 107 // If the operation does not implement the `RegionBranchOpInterface`, it is 108 // (currently) not possible to detect a loop. 109 RegionBranchOpInterface regionInterface; 110 if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op))) 111 return false; 112 113 // Recurses into a region using the current region interface to find potential 114 // cycles. 115 SmallPtrSet<Region *, 4> visitedRegions; 116 std::function<bool(Region *)> recurse = [&](Region *current) { 117 if (!current) 118 return false; 119 // If we have found a back edge, the parent operation induces a loop. 120 if (!visitedRegions.insert(current).second) 121 return true; 122 // Recurses into all region successors. 123 SmallVector<RegionSuccessor, 2> successors; 124 regionInterface.getSuccessorRegions(current->getRegionNumber(), successors); 125 for (RegionSuccessor ®ionEntry : successors) 126 if (recurse(regionEntry.getSuccessor())) 127 return true; 128 return false; 129 }; 130 131 // Start with all entry regions and test whether they induce a loop. 132 SmallVector<RegionSuccessor, 2> successorRegions; 133 regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions); 134 for (RegionSuccessor ®ionEntry : successorRegions) { 135 if (recurse(regionEntry.getSuccessor())) 136 return true; 137 visitedRegions.clear(); 138 } 139 140 return false; 141 } 142 143 //===----------------------------------------------------------------------===// 144 // BufferPlacementTransformationBase 145 //===----------------------------------------------------------------------===// 146 147 memref::GlobalOp GlobalCreator::getGlobalFor(arith::ConstantOp constantOp) { 148 auto type = constantOp.getType().cast<RankedTensorType>(); 149 150 BufferizeTypeConverter typeConverter; 151 152 // If we already have a global for this constant value, no need to do 153 // anything else. 154 auto it = globals.find(constantOp.getValue()); 155 if (it != globals.end()) 156 return cast<memref::GlobalOp>(it->second); 157 158 // Create a builder without an insertion point. We will insert using the 159 // symbol table to guarantee unique names. 160 OpBuilder globalBuilder(moduleOp.getContext()); 161 SymbolTable symbolTable(moduleOp); 162 163 // Create a pretty name. 164 SmallString<64> buf; 165 llvm::raw_svector_ostream os(buf); 166 interleave(type.getShape(), os, "x"); 167 os << "x" << type.getElementType(); 168 169 // Add an optional alignment to the global memref. 170 IntegerAttr memrefAlignment = 171 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) 172 : IntegerAttr(); 173 174 auto global = globalBuilder.create<memref::GlobalOp>( 175 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), 176 /*sym_visibility=*/globalBuilder.getStringAttr("private"), 177 /*type=*/typeConverter.convertType(type).cast<MemRefType>(), 178 /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), 179 /*constant=*/true, 180 /*alignment=*/memrefAlignment); 181 symbolTable.insert(global); 182 // The symbol table inserts at the end of the module, but globals are a bit 183 // nicer if they are at the beginning. 184 global->moveBefore(&moduleOp.front()); 185 globals[constantOp.getValue()] = global; 186 return global; 187 } 188