1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "../PassDetail.h"
10 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/OpenACC/OpenACC.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // DataDescriptor implementation
18 //===----------------------------------------------------------------------===//
19 
20 constexpr StringRef getStructName() { return "openacc_data"; }
21 
22 /// Construct a helper for the given descriptor value.
23 DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) {
24   assert(value != nullptr && "value cannot be null");
25 }
26 
27 /// Builds IR creating an `undef` value of the data descriptor.
28 DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc,
29                                      Type basePtrTy, Type ptrTy) {
30   Type descriptorType = LLVM::LLVMStructType::getNewIdentified(
31       builder.getContext(), getStructName(),
32       {basePtrTy, ptrTy, builder.getI64Type()});
33   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
34   return DataDescriptor(descriptor);
35 }
36 
37 /// Check whether the type is a valid data descriptor.
38 bool DataDescriptor::isValid(Value descriptor) {
39   if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) {
40     if (type.isIdentified() && type.getName().startswith(getStructName()) &&
41         type.getBody().size() == 3 &&
42         (type.getBody()[kPtrBasePosInDataDescriptor]
43              .isa<LLVM::LLVMPointerType>() ||
44          type.getBody()[kPtrBasePosInDataDescriptor]
45              .isa<LLVM::LLVMStructType>()) &&
46         type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() &&
47         type.getBody()[kSizePosInDataDescriptor].isInteger(64))
48       return true;
49   }
50   return false;
51 }
52 
53 /// Builds IR inserting the base pointer value into the descriptor.
54 void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc,
55                                     Value basePtr) {
56   setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr);
57 }
58 
59 /// Builds IR inserting the pointer value into the descriptor.
60 void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) {
61   setPtr(builder, loc, kPtrPosInDataDescriptor, ptr);
62 }
63 
64 /// Builds IR inserting the size value into the descriptor.
65 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
66   setPtr(builder, loc, kSizePosInDataDescriptor, size);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Conversion patterns
71 //===----------------------------------------------------------------------===//
72 
73 namespace {
74 
75 template <typename Op>
76 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
77   using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
78 
79   LogicalResult
80   matchAndRewrite(Op op, ArrayRef<Value> operands,
81                   ConversionPatternRewriter &builder) const override {
82     Location loc = op.getLoc();
83     TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
84 
85     unsigned numDataOperand = op.getNumDataOperands();
86 
87     // Keep the non data operands without modification.
88     auto nonDataOperands =
89         operands.take_front(operands.size() - numDataOperand);
90     SmallVector<Value> convertedOperands;
91     convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
92 
93     // Go over the data operand and legalize them for translation.
94     for (unsigned idx = 0; idx < numDataOperand; ++idx) {
95       Value originalDataOperand = op.getDataOperand(idx);
96 
97       // Traverse operands that were converted to MemRefDescriptors.
98       if (auto memRefType =
99               originalDataOperand.getType().dyn_cast<MemRefType>()) {
100         Type structType = converter->convertType(memRefType);
101         Value memRefDescriptor = builder
102                                      .create<LLVM::DialectCastOp>(
103                                          loc, structType, originalDataOperand)
104                                      .getResult();
105 
106         // Calculate the size of the memref and get the pointer to the allocated
107         // buffer.
108         SmallVector<Value> sizes;
109         SmallVector<Value> strides;
110         Value sizeBytes;
111         ConvertToLLVMPattern::getMemRefDescriptorSizes(
112             loc, memRefType, {}, builder, sizes, strides, sizeBytes);
113         MemRefDescriptor descriptor(memRefDescriptor);
114         Value dataPtr = descriptor.alignedPtr(builder, loc);
115         auto ptrType = descriptor.getElementPtrType();
116 
117         auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
118         descr.setBasePointer(builder, loc, memRefDescriptor);
119         descr.setPointer(builder, loc, dataPtr);
120         descr.setSize(builder, loc, sizeBytes);
121         convertedOperands.push_back(descr);
122       } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) {
123         convertedOperands.push_back(originalDataOperand);
124       } else {
125         // Type not supported.
126         return builder.notifyMatchFailure(op, "unsupported type");
127       }
128     }
129 
130     builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
131                                    op.getOperation()->getAttrs());
132 
133     return success();
134   }
135 };
136 } // namespace
137 
138 void mlir::populateOpenACCToLLVMConversionPatterns(
139     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
140   patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter);
141   patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
142   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
143   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
144   patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
145 }
146 
147 namespace {
148 struct ConvertOpenACCToLLVMPass
149     : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> {
150   void runOnOperation() override;
151 };
152 } // namespace
153 
154 void ConvertOpenACCToLLVMPass::runOnOperation() {
155   auto op = getOperation();
156   auto *context = op.getContext();
157 
158   // Convert to OpenACC operations with LLVM IR dialect
159   RewritePatternSet patterns(context);
160   LLVMTypeConverter converter(context);
161   populateOpenACCToLLVMConversionPatterns(converter, patterns);
162 
163   ConversionTarget target(*context);
164   target.addLegalDialect<LLVM::LLVMDialect>();
165 
166   auto allDataOperandsAreConverted = [](ValueRange operands) {
167     for (Value operand : operands) {
168       if (!DataDescriptor::isValid(operand) &&
169           !operand.getType().isa<LLVM::LLVMPointerType>())
170         return false;
171     }
172     return true;
173   };
174 
175   target.addDynamicallyLegalOp<acc::DataOp>(
176       [allDataOperandsAreConverted](acc::DataOp op) {
177         return allDataOperandsAreConverted(op.copyOperands()) &&
178                allDataOperandsAreConverted(op.copyinOperands()) &&
179                allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
180                allDataOperandsAreConverted(op.copyoutOperands()) &&
181                allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
182                allDataOperandsAreConverted(op.createOperands()) &&
183                allDataOperandsAreConverted(op.createZeroOperands()) &&
184                allDataOperandsAreConverted(op.noCreateOperands()) &&
185                allDataOperandsAreConverted(op.presentOperands()) &&
186                allDataOperandsAreConverted(op.deviceptrOperands()) &&
187                allDataOperandsAreConverted(op.attachOperands());
188       });
189 
190   target.addDynamicallyLegalOp<acc::EnterDataOp>(
191       [allDataOperandsAreConverted](acc::EnterDataOp op) {
192         return allDataOperandsAreConverted(op.copyinOperands()) &&
193                allDataOperandsAreConverted(op.createOperands()) &&
194                allDataOperandsAreConverted(op.createZeroOperands()) &&
195                allDataOperandsAreConverted(op.attachOperands());
196       });
197 
198   target.addDynamicallyLegalOp<acc::ExitDataOp>(
199       [allDataOperandsAreConverted](acc::ExitDataOp op) {
200         return allDataOperandsAreConverted(op.copyoutOperands()) &&
201                allDataOperandsAreConverted(op.deleteOperands()) &&
202                allDataOperandsAreConverted(op.detachOperands());
203       });
204 
205   target.addDynamicallyLegalOp<acc::ParallelOp>(
206       [allDataOperandsAreConverted](acc::ParallelOp op) {
207         return allDataOperandsAreConverted(op.reductionOperands()) &&
208                allDataOperandsAreConverted(op.copyOperands()) &&
209                allDataOperandsAreConverted(op.copyinOperands()) &&
210                allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
211                allDataOperandsAreConverted(op.copyoutOperands()) &&
212                allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
213                allDataOperandsAreConverted(op.createOperands()) &&
214                allDataOperandsAreConverted(op.createZeroOperands()) &&
215                allDataOperandsAreConverted(op.noCreateOperands()) &&
216                allDataOperandsAreConverted(op.presentOperands()) &&
217                allDataOperandsAreConverted(op.devicePtrOperands()) &&
218                allDataOperandsAreConverted(op.attachOperands()) &&
219                allDataOperandsAreConverted(op.gangPrivateOperands()) &&
220                allDataOperandsAreConverted(op.gangFirstPrivateOperands());
221       });
222 
223   target.addDynamicallyLegalOp<acc::UpdateOp>(
224       [allDataOperandsAreConverted](acc::UpdateOp op) {
225         return allDataOperandsAreConverted(op.hostOperands()) &&
226                allDataOperandsAreConverted(op.deviceOperands());
227       });
228 
229   if (failed(applyPartialConversion(op, target, std::move(patterns))))
230     signalPassFailure();
231 }
232 
233 std::unique_ptr<OperationPass<ModuleOp>>
234 mlir::createConvertOpenACCToLLVMPass() {
235   return std::make_unique<ConvertOpenACCToLLVMPass>();
236 }
237