1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Builders.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 using namespace mlir;
15 
16 LogicalResult
17 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
18                                    ConversionPatternRewriter &rewriter) const {
19   Location loc = gpuFuncOp.getLoc();
20 
21   SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
22   workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
23   for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
24     Value attribution = en.value();
25 
26     auto type = attribution.getType().dyn_cast<MemRefType>();
27     assert(type && type.hasStaticShape() && "unexpected type in attribution");
28 
29     uint64_t numElements = type.getNumElements();
30 
31     auto elementType =
32         typeConverter->convertType(type.getElementType()).template cast<Type>();
33     auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
34     std::string name = std::string(
35         llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
36     auto globalOp = rewriter.create<LLVM::GlobalOp>(
37         gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
38         LLVM::Linkage::Internal, name, /*value=*/Attribute(),
39         /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
40     workgroupBuffers.push_back(globalOp);
41   }
42 
43   // Rewrite the original GPU function to an LLVM function.
44   auto funcType = typeConverter->convertType(gpuFuncOp.getType())
45                       .template cast<LLVM::LLVMPointerType>()
46                       .getElementType();
47 
48   // Remap proper input types.
49   TypeConverter::SignatureConversion signatureConversion(
50       gpuFuncOp.front().getNumArguments());
51   getTypeConverter()->convertFunctionSignature(
52       gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
53 
54   // Create the new function operation. Only copy those attributes that are
55   // not specific to function modeling.
56   SmallVector<NamedAttribute, 4> attributes;
57   for (const auto &attr : gpuFuncOp->getAttrs()) {
58     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
59         attr.getName() == function_like_impl::getTypeAttrName() ||
60         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
61       continue;
62     attributes.push_back(attr);
63   }
64   // Add a dialect specific kernel attribute in addition to GPU kernel
65   // attribute. The former is necessary for further translation while the
66   // latter is expected by gpu.launch_func.
67   if (gpuFuncOp.isKernel())
68     attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
69   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
70       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
71       LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
72 
73   {
74     // Insert operations that correspond to converted workgroup and private
75     // memory attributions to the body of the function. This must operate on
76     // the original function, before the body region is inlined in the new
77     // function to maintain the relation between block arguments and the
78     // parent operation that assigns their semantics.
79     OpBuilder::InsertionGuard guard(rewriter);
80 
81     // Rewrite workgroup memory attributions to addresses of global buffers.
82     rewriter.setInsertionPointToStart(&gpuFuncOp.front());
83     unsigned numProperArguments = gpuFuncOp.getNumArguments();
84     auto i32Type = IntegerType::get(rewriter.getContext(), 32);
85 
86     Value zero = nullptr;
87     if (!workgroupBuffers.empty())
88       zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
89                                                rewriter.getI32IntegerAttr(0));
90     for (auto en : llvm::enumerate(workgroupBuffers)) {
91       LLVM::GlobalOp global = en.value();
92       Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
93       auto elementType =
94           global.getType().cast<LLVM::LLVMArrayType>().getElementType();
95       Value memory = rewriter.create<LLVM::GEPOp>(
96           loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
97           address, ArrayRef<Value>{zero, zero});
98 
99       // Build a memref descriptor pointing to the buffer to plug with the
100       // existing memref infrastructure. This may use more registers than
101       // otherwise necessary given that memref sizes are fixed, but we can try
102       // and canonicalize that away later.
103       Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
104       auto type = attribution.getType().cast<MemRefType>();
105       auto descr = MemRefDescriptor::fromStaticShape(
106           rewriter, loc, *getTypeConverter(), type, memory);
107       signatureConversion.remapInput(numProperArguments + en.index(), descr);
108     }
109 
110     // Rewrite private memory attributions to alloca'ed buffers.
111     unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
112     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
113     for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
114       Value attribution = en.value();
115       auto type = attribution.getType().cast<MemRefType>();
116       assert(type && type.hasStaticShape() && "unexpected type in attribution");
117 
118       // Explicitly drop memory space when lowering private memory
119       // attributions since NVVM models it as `alloca`s in the default
120       // memory space and does not support `alloca`s with addrspace(5).
121       auto ptrType = LLVM::LLVMPointerType::get(
122           typeConverter->convertType(type.getElementType())
123               .template cast<Type>(),
124           allocaAddrSpace);
125       Value numElements = rewriter.create<LLVM::ConstantOp>(
126           gpuFuncOp.getLoc(), int64Ty,
127           rewriter.getI64IntegerAttr(type.getNumElements()));
128       Value allocated = rewriter.create<LLVM::AllocaOp>(
129           gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
130       auto descr = MemRefDescriptor::fromStaticShape(
131           rewriter, loc, *getTypeConverter(), type, allocated);
132       signatureConversion.remapInput(
133           numProperArguments + numWorkgroupAttributions + en.index(), descr);
134     }
135   }
136 
137   // Move the region to the new function, update the entry block signature.
138   rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
139                               llvmFuncOp.end());
140   if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
141                                          &signatureConversion)))
142     return failure();
143 
144   rewriter.eraseOp(gpuFuncOp);
145   return success();
146 }
147