1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "GPUOpsLowering.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Builders.h" 12 #include "llvm/Support/FormatVariadic.h" 13 14 using namespace mlir; 15 16 LogicalResult 17 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, 18 ConversionPatternRewriter &rewriter) const { 19 Location loc = gpuFuncOp.getLoc(); 20 21 SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; 22 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); 23 for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { 24 Value attribution = en.value(); 25 26 auto type = attribution.getType().dyn_cast<MemRefType>(); 27 assert(type && type.hasStaticShape() && "unexpected type in attribution"); 28 29 uint64_t numElements = type.getNumElements(); 30 31 auto elementType = 32 typeConverter->convertType(type.getElementType()).template cast<Type>(); 33 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); 34 std::string name = std::string( 35 llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); 36 auto globalOp = rewriter.create<LLVM::GlobalOp>( 37 gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, 38 LLVM::Linkage::Internal, name, /*value=*/Attribute(), 39 /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace()); 40 workgroupBuffers.push_back(globalOp); 41 } 42 43 // Rewrite the original GPU function to an LLVM function. 44 auto funcType = typeConverter->convertType(gpuFuncOp.getType()) 45 .template cast<LLVM::LLVMPointerType>() 46 .getElementType(); 47 48 // Remap proper input types. 49 TypeConverter::SignatureConversion signatureConversion( 50 gpuFuncOp.front().getNumArguments()); 51 getTypeConverter()->convertFunctionSignature( 52 gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion); 53 54 // Create the new function operation. Only copy those attributes that are 55 // not specific to function modeling. 56 SmallVector<NamedAttribute, 4> attributes; 57 for (const auto &attr : gpuFuncOp->getAttrs()) { 58 if (attr.getName() == SymbolTable::getSymbolAttrName() || 59 attr.getName() == function_like_impl::getTypeAttrName() || 60 attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) 61 continue; 62 attributes.push_back(attr); 63 } 64 // Add a dialect specific kernel attribute in addition to GPU kernel 65 // attribute. The former is necessary for further translation while the 66 // latter is expected by gpu.launch_func. 67 if (gpuFuncOp.isKernel()) 68 attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); 69 auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 70 gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, 71 LLVM::Linkage::External, /*dsoLocal*/ false, attributes); 72 73 { 74 // Insert operations that correspond to converted workgroup and private 75 // memory attributions to the body of the function. This must operate on 76 // the original function, before the body region is inlined in the new 77 // function to maintain the relation between block arguments and the 78 // parent operation that assigns their semantics. 79 OpBuilder::InsertionGuard guard(rewriter); 80 81 // Rewrite workgroup memory attributions to addresses of global buffers. 82 rewriter.setInsertionPointToStart(&gpuFuncOp.front()); 83 unsigned numProperArguments = gpuFuncOp.getNumArguments(); 84 auto i32Type = IntegerType::get(rewriter.getContext(), 32); 85 86 Value zero = nullptr; 87 if (!workgroupBuffers.empty()) 88 zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 89 rewriter.getI32IntegerAttr(0)); 90 for (auto en : llvm::enumerate(workgroupBuffers)) { 91 LLVM::GlobalOp global = en.value(); 92 Value address = rewriter.create<LLVM::AddressOfOp>(loc, global); 93 auto elementType = 94 global.getType().cast<LLVM::LLVMArrayType>().getElementType(); 95 Value memory = rewriter.create<LLVM::GEPOp>( 96 loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()), 97 address, ArrayRef<Value>{zero, zero}); 98 99 // Build a memref descriptor pointing to the buffer to plug with the 100 // existing memref infrastructure. This may use more registers than 101 // otherwise necessary given that memref sizes are fixed, but we can try 102 // and canonicalize that away later. 103 Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; 104 auto type = attribution.getType().cast<MemRefType>(); 105 auto descr = MemRefDescriptor::fromStaticShape( 106 rewriter, loc, *getTypeConverter(), type, memory); 107 signatureConversion.remapInput(numProperArguments + en.index(), descr); 108 } 109 110 // Rewrite private memory attributions to alloca'ed buffers. 111 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); 112 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 113 for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { 114 Value attribution = en.value(); 115 auto type = attribution.getType().cast<MemRefType>(); 116 assert(type && type.hasStaticShape() && "unexpected type in attribution"); 117 118 // Explicitly drop memory space when lowering private memory 119 // attributions since NVVM models it as `alloca`s in the default 120 // memory space and does not support `alloca`s with addrspace(5). 121 auto ptrType = LLVM::LLVMPointerType::get( 122 typeConverter->convertType(type.getElementType()) 123 .template cast<Type>(), 124 allocaAddrSpace); 125 Value numElements = rewriter.create<LLVM::ConstantOp>( 126 gpuFuncOp.getLoc(), int64Ty, 127 rewriter.getI64IntegerAttr(type.getNumElements())); 128 Value allocated = rewriter.create<LLVM::AllocaOp>( 129 gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); 130 auto descr = MemRefDescriptor::fromStaticShape( 131 rewriter, loc, *getTypeConverter(), type, allocated); 132 signatureConversion.remapInput( 133 numProperArguments + numWorkgroupAttributions + en.index(), descr); 134 } 135 } 136 137 // Move the region to the new function, update the entry block signature. 138 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), 139 llvmFuncOp.end()); 140 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, 141 &signatureConversion))) 142 return failure(); 143 144 rewriter.eraseOp(gpuFuncOp); 145 return success(); 146 } 147