1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "../PassDetail.h"
10 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/OpenACC/OpenACC.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // DataDescriptor implementation
18 //===----------------------------------------------------------------------===//
19 
20 constexpr StringRef getStructName() { return "openacc_data"; }
21 
22 /// Construct a helper for the given descriptor value.
23 DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) {
24   assert(value != nullptr && "value cannot be null");
25 }
26 
27 /// Builds IR creating an `undef` value of the data descriptor.
28 DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc,
29                                      Type basePtrTy, Type ptrTy) {
30   Type descriptorType = LLVM::LLVMStructType::getNewIdentified(
31       builder.getContext(), getStructName(),
32       {basePtrTy, ptrTy, builder.getI64Type()});
33   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
34   return DataDescriptor(descriptor);
35 }
36 
37 /// Check whether the type is a valid data descriptor.
38 bool DataDescriptor::isValid(Value descriptor) {
39   if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) {
40     if (type.isIdentified() && type.getName().startswith(getStructName()) &&
41         type.getBody().size() == 3 &&
42         (type.getBody()[kPtrBasePosInDataDescriptor]
43              .isa<LLVM::LLVMPointerType>() ||
44          type.getBody()[kPtrBasePosInDataDescriptor]
45              .isa<LLVM::LLVMStructType>()) &&
46         type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() &&
47         type.getBody()[kSizePosInDataDescriptor].isInteger(64))
48       return true;
49   }
50   return false;
51 }
52 
53 /// Builds IR inserting the base pointer value into the descriptor.
54 void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc,
55                                     Value basePtr) {
56   setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr);
57 }
58 
59 /// Builds IR inserting the pointer value into the descriptor.
60 void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) {
61   setPtr(builder, loc, kPtrPosInDataDescriptor, ptr);
62 }
63 
64 /// Builds IR inserting the size value into the descriptor.
65 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
66   setPtr(builder, loc, kSizePosInDataDescriptor, size);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Conversion patterns
71 //===----------------------------------------------------------------------===//
72 
73 template <typename Op>
74 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
75   using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
76 
77   LogicalResult
78   matchAndRewrite(Op op, ArrayRef<Value> operands,
79                   ConversionPatternRewriter &builder) const override {
80     Location loc = op.getLoc();
81     TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
82 
83     unsigned numDataOperand = op.getNumDataOperands();
84 
85     // Keep the non data operands without modification.
86     auto nonDataOperands =
87         operands.take_front(operands.size() - numDataOperand);
88     SmallVector<Value> convertedOperands;
89     convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
90 
91     // Go over the data operand and legalize them for translation.
92     for (unsigned idx = 0; idx < numDataOperand; ++idx) {
93       Value originalDataOperand = op.getDataOperand(idx);
94 
95       // Traverse operands that were converted to MemRefDescriptors.
96       if (auto memRefType =
97               originalDataOperand.getType().dyn_cast<MemRefType>()) {
98         Type structType = converter->convertType(memRefType);
99         Value memRefDescriptor = builder
100                                      .create<LLVM::DialectCastOp>(
101                                          loc, structType, originalDataOperand)
102                                      .getResult();
103 
104         // Calculate the size of the memref and get the pointer to the allocated
105         // buffer.
106         SmallVector<Value> sizes;
107         SmallVector<Value> strides;
108         Value sizeBytes;
109         ConvertToLLVMPattern::getMemRefDescriptorSizes(
110             loc, memRefType, {}, builder, sizes, strides, sizeBytes);
111         MemRefDescriptor descriptor(memRefDescriptor);
112         Value dataPtr = descriptor.alignedPtr(builder, loc);
113         auto ptrType = descriptor.getElementPtrType();
114 
115         auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
116         descr.setBasePointer(builder, loc, memRefDescriptor);
117         descr.setPointer(builder, loc, dataPtr);
118         descr.setSize(builder, loc, sizeBytes);
119         convertedOperands.push_back(descr);
120       } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) {
121         convertedOperands.push_back(originalDataOperand);
122       } else {
123         // Type not supported.
124         return builder.notifyMatchFailure(op, "unsupported type");
125       }
126     }
127 
128     builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
129                                    op.getOperation()->getAttrs());
130 
131     return success();
132   }
133 };
134 
135 void mlir::populateOpenACCToLLVMConversionPatterns(
136     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
137   patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
138   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
139   patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
140 }
141 
142 namespace {
143 struct ConvertOpenACCToLLVMPass
144     : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> {
145   void runOnOperation() override;
146 };
147 } // namespace
148 
149 void ConvertOpenACCToLLVMPass::runOnOperation() {
150   auto op = getOperation();
151   auto *context = op.getContext();
152 
153   // Convert to OpenACC operations with LLVM IR dialect
154   RewritePatternSet patterns(context);
155   LLVMTypeConverter converter(context);
156   populateOpenACCToLLVMConversionPatterns(converter, patterns);
157 
158   ConversionTarget target(*context);
159   target.addLegalDialect<LLVM::LLVMDialect>();
160 
161   auto allDataOperandsAreConverted = [](ValueRange operands) {
162     for (Value operand : operands) {
163       if (!DataDescriptor::isValid(operand) &&
164           !operand.getType().isa<LLVM::LLVMPointerType>())
165         return false;
166     }
167     return true;
168   };
169 
170   target.addDynamicallyLegalOp<acc::EnterDataOp>(
171       [allDataOperandsAreConverted](acc::EnterDataOp op) {
172         return allDataOperandsAreConverted(op.copyinOperands()) &&
173                allDataOperandsAreConverted(op.createOperands()) &&
174                allDataOperandsAreConverted(op.createZeroOperands()) &&
175                allDataOperandsAreConverted(op.attachOperands());
176       });
177 
178   target.addDynamicallyLegalOp<acc::ExitDataOp>(
179       [allDataOperandsAreConverted](acc::ExitDataOp op) {
180         return allDataOperandsAreConverted(op.copyoutOperands()) &&
181                allDataOperandsAreConverted(op.deleteOperands()) &&
182                allDataOperandsAreConverted(op.detachOperands());
183       });
184 
185   target.addDynamicallyLegalOp<acc::UpdateOp>(
186       [allDataOperandsAreConverted](acc::UpdateOp op) {
187         return allDataOperandsAreConverted(op.hostOperands()) &&
188                allDataOperandsAreConverted(op.deviceOperands());
189       });
190 
191   if (failed(applyPartialConversion(op, target, std::move(patterns))))
192     signalPassFailure();
193 }
194 
195 std::unique_ptr<OperationPass<ModuleOp>>
196 mlir::createConvertOpenACCToLLVMPass() {
197   return std::make_unique<ConvertOpenACCToLLVMPass>();
198 }
199