//===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
86110b667SValentin Clement 
96110b667SValentin Clement #include "../PassDetail.h"
1075e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
116110b667SValentin Clement #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
126110b667SValentin Clement #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
136110b667SValentin Clement #include "mlir/Dialect/OpenACC/OpenACC.h"
1475e5f0aaSAlex Zinenko #include "mlir/IR/Builders.h"
156110b667SValentin Clement 
166110b667SValentin Clement using namespace mlir;
176110b667SValentin Clement 
//===----------------------------------------------------------------------===//
// DataDescriptor implementation
//===----------------------------------------------------------------------===//
216110b667SValentin Clement 
getStructName()226110b667SValentin Clement constexpr StringRef getStructName() { return "openacc_data"; }
236110b667SValentin Clement 
246110b667SValentin Clement /// Construct a helper for the given descriptor value.
DataDescriptor(Value descriptor)256110b667SValentin Clement DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) {
266110b667SValentin Clement   assert(value != nullptr && "value cannot be null");
276110b667SValentin Clement }
286110b667SValentin Clement 
296110b667SValentin Clement /// Builds IR creating an `undef` value of the data descriptor.
undef(OpBuilder & builder,Location loc,Type basePtrTy,Type ptrTy)306110b667SValentin Clement DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc,
316110b667SValentin Clement                                      Type basePtrTy, Type ptrTy) {
326110b667SValentin Clement   Type descriptorType = LLVM::LLVMStructType::getNewIdentified(
336110b667SValentin Clement       builder.getContext(), getStructName(),
346110b667SValentin Clement       {basePtrTy, ptrTy, builder.getI64Type()});
356110b667SValentin Clement   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
366110b667SValentin Clement   return DataDescriptor(descriptor);
376110b667SValentin Clement }
386110b667SValentin Clement 
396110b667SValentin Clement /// Check whether the type is a valid data descriptor.
isValid(Value descriptor)406110b667SValentin Clement bool DataDescriptor::isValid(Value descriptor) {
416110b667SValentin Clement   if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) {
426110b667SValentin Clement     if (type.isIdentified() && type.getName().startswith(getStructName()) &&
436110b667SValentin Clement         type.getBody().size() == 3 &&
446110b667SValentin Clement         (type.getBody()[kPtrBasePosInDataDescriptor]
456110b667SValentin Clement              .isa<LLVM::LLVMPointerType>() ||
466110b667SValentin Clement          type.getBody()[kPtrBasePosInDataDescriptor]
476110b667SValentin Clement              .isa<LLVM::LLVMStructType>()) &&
486110b667SValentin Clement         type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() &&
496110b667SValentin Clement         type.getBody()[kSizePosInDataDescriptor].isInteger(64))
506110b667SValentin Clement       return true;
516110b667SValentin Clement   }
526110b667SValentin Clement   return false;
536110b667SValentin Clement }
546110b667SValentin Clement 
556110b667SValentin Clement /// Builds IR inserting the base pointer value into the descriptor.
setBasePointer(OpBuilder & builder,Location loc,Value basePtr)566110b667SValentin Clement void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc,
576110b667SValentin Clement                                     Value basePtr) {
586110b667SValentin Clement   setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr);
596110b667SValentin Clement }
606110b667SValentin Clement 
616110b667SValentin Clement /// Builds IR inserting the pointer value into the descriptor.
setPointer(OpBuilder & builder,Location loc,Value ptr)626110b667SValentin Clement void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) {
636110b667SValentin Clement   setPtr(builder, loc, kPtrPosInDataDescriptor, ptr);
646110b667SValentin Clement }
656110b667SValentin Clement 
666110b667SValentin Clement /// Builds IR inserting the size value into the descriptor.
setSize(OpBuilder & builder,Location loc,Value size)676110b667SValentin Clement void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
686110b667SValentin Clement   setPtr(builder, loc, kSizePosInDataDescriptor, size);
696110b667SValentin Clement }
706110b667SValentin Clement 
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
746110b667SValentin Clement 
758fdfead7SValentin Clement namespace {
768fdfead7SValentin Clement 
776110b667SValentin Clement template <typename Op>
786110b667SValentin Clement class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
796110b667SValentin Clement   using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
806110b667SValentin Clement 
816110b667SValentin Clement   LogicalResult
matchAndRewrite(Op op,typename Op::Adaptor adaptor,ConversionPatternRewriter & builder) const82ef976337SRiver Riddle   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836110b667SValentin Clement                   ConversionPatternRewriter &builder) const override {
846110b667SValentin Clement     Location loc = op.getLoc();
856110b667SValentin Clement     TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
866110b667SValentin Clement 
876110b667SValentin Clement     unsigned numDataOperand = op.getNumDataOperands();
886110b667SValentin Clement 
896110b667SValentin Clement     // Keep the non data operands without modification.
90ef976337SRiver Riddle     auto nonDataOperands = adaptor.getOperands().take_front(
91ef976337SRiver Riddle         adaptor.getOperands().size() - numDataOperand);
926110b667SValentin Clement     SmallVector<Value> convertedOperands;
936110b667SValentin Clement     convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
946110b667SValentin Clement 
956110b667SValentin Clement     // Go over the data operand and legalize them for translation.
966110b667SValentin Clement     for (unsigned idx = 0; idx < numDataOperand; ++idx) {
976110b667SValentin Clement       Value originalDataOperand = op.getDataOperand(idx);
986110b667SValentin Clement 
996110b667SValentin Clement       // Traverse operands that were converted to MemRefDescriptors.
1006110b667SValentin Clement       if (auto memRefType =
1016110b667SValentin Clement               originalDataOperand.getType().dyn_cast<MemRefType>()) {
1026110b667SValentin Clement         Type structType = converter->convertType(memRefType);
1036110b667SValentin Clement         Value memRefDescriptor = builder
104881dc34fSAlex Zinenko                                      .create<UnrealizedConversionCastOp>(
1056110b667SValentin Clement                                          loc, structType, originalDataOperand)
106881dc34fSAlex Zinenko                                      .getResult(0);
1076110b667SValentin Clement 
1086110b667SValentin Clement         // Calculate the size of the memref and get the pointer to the allocated
1096110b667SValentin Clement         // buffer.
1106110b667SValentin Clement         SmallVector<Value> sizes;
1116110b667SValentin Clement         SmallVector<Value> strides;
1126110b667SValentin Clement         Value sizeBytes;
1136110b667SValentin Clement         ConvertToLLVMPattern::getMemRefDescriptorSizes(
1146110b667SValentin Clement             loc, memRefType, {}, builder, sizes, strides, sizeBytes);
1156110b667SValentin Clement         MemRefDescriptor descriptor(memRefDescriptor);
1166110b667SValentin Clement         Value dataPtr = descriptor.alignedPtr(builder, loc);
1176110b667SValentin Clement         auto ptrType = descriptor.getElementPtrType();
1186110b667SValentin Clement 
1196110b667SValentin Clement         auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
1206110b667SValentin Clement         descr.setBasePointer(builder, loc, memRefDescriptor);
1216110b667SValentin Clement         descr.setPointer(builder, loc, dataPtr);
1226110b667SValentin Clement         descr.setSize(builder, loc, sizeBytes);
1236110b667SValentin Clement         convertedOperands.push_back(descr);
1246110b667SValentin Clement       } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) {
1256110b667SValentin Clement         convertedOperands.push_back(originalDataOperand);
1266110b667SValentin Clement       } else {
1276110b667SValentin Clement         // Type not supported.
1286110b667SValentin Clement         return builder.notifyMatchFailure(op, "unsupported type");
1296110b667SValentin Clement       }
1306110b667SValentin Clement     }
1316110b667SValentin Clement 
1326110b667SValentin Clement     builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
1336110b667SValentin Clement                                    op.getOperation()->getAttrs());
1346110b667SValentin Clement 
1356110b667SValentin Clement     return success();
1366110b667SValentin Clement   }
1376110b667SValentin Clement };
1388fdfead7SValentin Clement } // namespace
1396110b667SValentin Clement 
populateOpenACCToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)1406110b667SValentin Clement void mlir::populateOpenACCToLLVMConversionPatterns(
141*9f85c198SRiver Riddle     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
142fcb15472SValentin Clement   patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter);
1436110b667SValentin Clement   patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
1446110b667SValentin Clement   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
145cfcdebafSValentin Clement   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
1466110b667SValentin Clement   patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
1476110b667SValentin Clement }
1486110b667SValentin Clement 
1496110b667SValentin Clement namespace {
1506110b667SValentin Clement struct ConvertOpenACCToLLVMPass
1516110b667SValentin Clement     : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> {
1526110b667SValentin Clement   void runOnOperation() override;
1536110b667SValentin Clement };
1546110b667SValentin Clement } // namespace
1556110b667SValentin Clement 
runOnOperation()1566110b667SValentin Clement void ConvertOpenACCToLLVMPass::runOnOperation() {
1576110b667SValentin Clement   auto op = getOperation();
1586110b667SValentin Clement   auto *context = op.getContext();
1596110b667SValentin Clement 
1606110b667SValentin Clement   // Convert to OpenACC operations with LLVM IR dialect
1616110b667SValentin Clement   RewritePatternSet patterns(context);
1626110b667SValentin Clement   LLVMTypeConverter converter(context);
1636110b667SValentin Clement   populateOpenACCToLLVMConversionPatterns(converter, patterns);
1646110b667SValentin Clement 
1656110b667SValentin Clement   ConversionTarget target(*context);
1666110b667SValentin Clement   target.addLegalDialect<LLVM::LLVMDialect>();
167881dc34fSAlex Zinenko   target.addLegalOp<UnrealizedConversionCastOp>();
1686110b667SValentin Clement 
1696110b667SValentin Clement   auto allDataOperandsAreConverted = [](ValueRange operands) {
1706110b667SValentin Clement     for (Value operand : operands) {
1716110b667SValentin Clement       if (!DataDescriptor::isValid(operand) &&
1726110b667SValentin Clement           !operand.getType().isa<LLVM::LLVMPointerType>())
1736110b667SValentin Clement         return false;
1746110b667SValentin Clement     }
1756110b667SValentin Clement     return true;
1766110b667SValentin Clement   };
1776110b667SValentin Clement 
178fcb15472SValentin Clement   target.addDynamicallyLegalOp<acc::DataOp>(
179fcb15472SValentin Clement       [allDataOperandsAreConverted](acc::DataOp op) {
180fcb15472SValentin Clement         return allDataOperandsAreConverted(op.copyOperands()) &&
181fcb15472SValentin Clement                allDataOperandsAreConverted(op.copyinOperands()) &&
182fcb15472SValentin Clement                allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
183fcb15472SValentin Clement                allDataOperandsAreConverted(op.copyoutOperands()) &&
184fcb15472SValentin Clement                allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
185fcb15472SValentin Clement                allDataOperandsAreConverted(op.createOperands()) &&
186fcb15472SValentin Clement                allDataOperandsAreConverted(op.createZeroOperands()) &&
187fcb15472SValentin Clement                allDataOperandsAreConverted(op.noCreateOperands()) &&
188fcb15472SValentin Clement                allDataOperandsAreConverted(op.presentOperands()) &&
189fcb15472SValentin Clement                allDataOperandsAreConverted(op.deviceptrOperands()) &&
190fcb15472SValentin Clement                allDataOperandsAreConverted(op.attachOperands());
191fcb15472SValentin Clement       });
192fcb15472SValentin Clement 
1936110b667SValentin Clement   target.addDynamicallyLegalOp<acc::EnterDataOp>(
1946110b667SValentin Clement       [allDataOperandsAreConverted](acc::EnterDataOp op) {
1956110b667SValentin Clement         return allDataOperandsAreConverted(op.copyinOperands()) &&
1966110b667SValentin Clement                allDataOperandsAreConverted(op.createOperands()) &&
1976110b667SValentin Clement                allDataOperandsAreConverted(op.createZeroOperands()) &&
1986110b667SValentin Clement                allDataOperandsAreConverted(op.attachOperands());
1996110b667SValentin Clement       });
2006110b667SValentin Clement 
2016110b667SValentin Clement   target.addDynamicallyLegalOp<acc::ExitDataOp>(
2026110b667SValentin Clement       [allDataOperandsAreConverted](acc::ExitDataOp op) {
2036110b667SValentin Clement         return allDataOperandsAreConverted(op.copyoutOperands()) &&
2046110b667SValentin Clement                allDataOperandsAreConverted(op.deleteOperands()) &&
2056110b667SValentin Clement                allDataOperandsAreConverted(op.detachOperands());
2066110b667SValentin Clement       });
2076110b667SValentin Clement 
208cfcdebafSValentin Clement   target.addDynamicallyLegalOp<acc::ParallelOp>(
209cfcdebafSValentin Clement       [allDataOperandsAreConverted](acc::ParallelOp op) {
210cfcdebafSValentin Clement         return allDataOperandsAreConverted(op.reductionOperands()) &&
211cfcdebafSValentin Clement                allDataOperandsAreConverted(op.copyOperands()) &&
212cfcdebafSValentin Clement                allDataOperandsAreConverted(op.copyinOperands()) &&
213cfcdebafSValentin Clement                allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
214cfcdebafSValentin Clement                allDataOperandsAreConverted(op.copyoutOperands()) &&
215cfcdebafSValentin Clement                allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
216cfcdebafSValentin Clement                allDataOperandsAreConverted(op.createOperands()) &&
217cfcdebafSValentin Clement                allDataOperandsAreConverted(op.createZeroOperands()) &&
218cfcdebafSValentin Clement                allDataOperandsAreConverted(op.noCreateOperands()) &&
219cfcdebafSValentin Clement                allDataOperandsAreConverted(op.presentOperands()) &&
220cfcdebafSValentin Clement                allDataOperandsAreConverted(op.devicePtrOperands()) &&
221cfcdebafSValentin Clement                allDataOperandsAreConverted(op.attachOperands()) &&
222cfcdebafSValentin Clement                allDataOperandsAreConverted(op.gangPrivateOperands()) &&
223cfcdebafSValentin Clement                allDataOperandsAreConverted(op.gangFirstPrivateOperands());
224cfcdebafSValentin Clement       });
225cfcdebafSValentin Clement 
2266110b667SValentin Clement   target.addDynamicallyLegalOp<acc::UpdateOp>(
2276110b667SValentin Clement       [allDataOperandsAreConverted](acc::UpdateOp op) {
2286110b667SValentin Clement         return allDataOperandsAreConverted(op.hostOperands()) &&
2296110b667SValentin Clement                allDataOperandsAreConverted(op.deviceOperands());
2306110b667SValentin Clement       });
2316110b667SValentin Clement 
2326110b667SValentin Clement   if (failed(applyPartialConversion(op, target, std::move(patterns))))
2336110b667SValentin Clement     signalPassFailure();
2346110b667SValentin Clement }
2356110b667SValentin Clement 
2366110b667SValentin Clement std::unique_ptr<OperationPass<ModuleOp>>
createConvertOpenACCToLLVMPass()2376110b667SValentin Clement mlir::createConvertOpenACCToLLVMPass() {
2386110b667SValentin Clement   return std::make_unique<ConvertOpenACCToLLVMPass>();
2396110b667SValentin Clement }
240