//===- MemoryPromotion.cpp - Utilities for moving data across GPU memories ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities that allow one to create IR moving the data
// across different levels of the GPU memory hierarchy.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/Transforms/MemoryPromotion.h"

#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::gpu;

/// Emits the (imperfect) loop nest performing the copy between "from" and "to"
/// values using the bounds derived from the "from" value. Emits at least
/// GPUDialect::getNumWorkgroupDimensions() loops, completing the nest with
/// single-iteration loops. Maps the innermost loops to thread dimensions, in
/// reverse order to enable access coalescing in the innermost loop.
static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) {
  auto memRefType = from.getType().cast<MemRefType>();
  auto rank = memRefType.getRank();

  SmallVector<Value, 4> lbs, ubs, steps;
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value one = b.create<arith::ConstantIndexOp>(1);

  // Make sure we have enough loops to use all thread dimensions; these
  // trivial loops should be outermost and are therefore inserted first.
  if (rank < GPUDialect::getNumWorkgroupDimensions()) {
    unsigned extraLoops = GPUDialect::getNumWorkgroupDimensions() - rank;
    lbs.resize(extraLoops, zero);
    ubs.resize(extraLoops, one);
    steps.resize(extraLoops, one);
  }

  // Add existing bounds.
  lbs.append(rank, zero);
  ubs.reserve(lbs.size());
  steps.reserve(lbs.size());
  for (auto idx = 0; idx < rank; ++idx) {
    ubs.push_back(b.createOrFold<memref::DimOp>(
        from, b.create<arith::ConstantIndexOp>(idx)));
    steps.push_back(one);
  }

  // Obtain thread identifiers and block sizes, necessary to map to them.
  auto indexType = b.getIndexType();
  SmallVector<Value, 3> threadIds, blockDims;
  for (auto dim : {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}) {
    threadIds.push_back(b.create<gpu::ThreadIdOp>(indexType, dim));
    blockDims.push_back(b.create<gpu::BlockDimOp>(indexType, dim));
  }

  // Produce the loop nest with copies.
  SmallVector<Value, 8> ivs(lbs.size());
  mlir::scf::buildLoopNest(
      b, b.getLoc(), lbs, ubs, steps,
      [&](OpBuilder &b, Location loc, ValueRange loopIvs) {
        ivs.assign(loopIvs.begin(), loopIvs.end());
        auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank);
        Value loaded = b.create<memref::LoadOp>(loc, from, activeIvs);
        b.create<memref::StoreOp>(loc, loaded, to, activeIvs);
      });

  // Map the innermost loops to threads in reverse order.
  for (const auto &en :
       llvm::enumerate(llvm::reverse(llvm::makeArrayRef(ivs).take_back(
           GPUDialect::getNumWorkgroupDimensions())))) {
    Value v = en.value();
    auto loop = cast<scf::ForOp>(v.getParentRegion()->getParentOp());
    mapLoopToProcessorIds(loop, {threadIds[en.index()]},
                          {blockDims[en.index()]});
  }
}
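
// For illustration (a sketch; SSA value names and the exact op order are not
// meant to be literal), copying a rank-3 memref<?x?x?xf32> produces a nest
// equivalent to:
//
//   %d0 = memref.dim %from, %c0
//   %d1 = memref.dim %from, %c1
//   %d2 = memref.dim %from, %c2
//   scf.for %i = %tid_z to %d0 step %bdim_z {
//     scf.for %j = %tid_y to %d1 step %bdim_y {
//       scf.for %k = %tid_x to %d2 step %bdim_x {
//         %v = memref.load %from[%i, %j, %k]
//         memref.store %v, %to[%i, %j, %k]
//       }
//     }
//   }
//
// The reverse mapping assigns the x dimension to the innermost loop so that
// consecutive threads touch consecutive elements, enabling coalesced accesses.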

/// Emits the loop nests performing the copy to the designated location in the
/// beginning of the region, and from the designated location immediately
/// before the terminator of the first block of the region. The region is
/// expected to have one block. This boils down to the following structure:
///
///   ^bb(...):
///     <loop-bound-computation>
///     for %arg0 = ... to ... step ... {
///       ...
///       for %argN = <thread-id-x> to ... step <block-dim-x> {
///         %0 = load %from[%arg0, ..., %argN]
///         store %0, %to[%arg0, ..., %argN]
///       }
///       ...
///     }
///     gpu.barrier
///     <... original body ...>
///     gpu.barrier
///     for %arg0 = ... to ... step ... {
///       ...
///       for %argN = <thread-id-x> to ... step <block-dim-x> {
///         %1 = load %to[%arg0, ..., %argN]
///         store %1, %from[%arg0, ..., %argN]
///       }
///       ...
///     }
///
/// Inserts the barriers unconditionally since different threads may be copying
/// values and reading them. An analysis would be required to eliminate
/// barriers in cases where a value is only used by the thread that copies it.
/// Both copies are inserted unconditionally; an analysis would be required to
/// only copy live-in and live-out values when necessary. This copies the
/// entire memref pointed to by "from". In case a smaller block would be
/// sufficient, the caller can create a subview of the memref and promote it
/// instead.
static void insertCopies(Region &region, Location loc, Value from, Value to) {
  auto fromType = from.getType().cast<MemRefType>();
  auto toType = to.getType().cast<MemRefType>();
  (void)fromType;
  (void)toType;
  assert(fromType.getShape() == toType.getShape());
  assert(fromType.getRank() != 0);
  assert(llvm::hasSingleElement(region) &&
         "unstructured control flow not supported");

  auto b = ImplicitLocOpBuilder::atBlockBegin(loc, &region.front());
  insertCopyLoops(b, from, to);
  b.create<gpu::BarrierOp>();

  b.setInsertionPoint(&region.front().back());
  b.create<gpu::BarrierOp>();
  insertCopyLoops(b, to, from);
}
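
// For example (a hedged sketch; `b`, `buffer`, `offsets`, `sizes`, `strides`,
// and `attribution` stand for values the caller would already have), promoting
// only a tile of a larger buffer amounts to copying through a subview rather
// than the whole memref:
//
//   Value tile = b.create<memref::SubViewOp>(loc, buffer, offsets, sizes,
//                                            strides);
//   insertCopies(region, loc, tile, attribution);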

/// Promotes a function argument to workgroup memory in the given function. The
/// copies will be inserted at the beginning and at the end of the function.
void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) {
  Value value = op.getArgument(arg);
  auto type = value.getType().dyn_cast<MemRefType>();
  assert(type && type.hasStaticShape() &&
         "can only promote statically shaped memrefs");

  // Get the type of the buffer in the workgroup memory.
  int workgroupMemoryAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  auto bufferType = MemRefType::get(type.getShape(), type.getElementType(), {},
                                    workgroupMemoryAddressSpace);
  Value attribution = op.addWorkgroupAttribution(bufferType, value.getLoc());

  // Replace the uses first since only the original uses are currently present.
  // Then insert the copies.
  value.replaceAllUsesWith(attribution);
  insertCopies(op.getBody(), op.getLoc(), value, attribution);
}
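
/// A minimal usage sketch (illustrative only; this helper is an assumption
/// rather than part of the upstream API, which is exercised in-tree through a
/// test pass): promotes every eligible argument of a kernel function. The
/// filter mirrors the precondition asserted by promoteToWorkgroupMemory.
static void promoteAllEligibleArguments(GPUFuncOp op) {
  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
    auto type = op.getArgument(i).getType().dyn_cast<MemRefType>();
    // Only statically shaped memref arguments can currently be promoted.
    if (type && type.hasStaticShape())
      promoteToWorkgroupMemory(op, i);
  }
}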