1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "../PassDetail.h" 10 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/OpenACC/OpenACC.h" 13 14 using namespace mlir; 15 16 //===----------------------------------------------------------------------===// 17 // DataDescriptor implementation 18 //===----------------------------------------------------------------------===// 19 20 constexpr StringRef getStructName() { return "openacc_data"; } 21 22 /// Construct a helper for the given descriptor value. 23 DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) { 24 assert(value != nullptr && "value cannot be null"); 25 } 26 27 /// Builds IR creating an `undef` value of the data descriptor. 28 DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc, 29 Type basePtrTy, Type ptrTy) { 30 Type descriptorType = LLVM::LLVMStructType::getNewIdentified( 31 builder.getContext(), getStructName(), 32 {basePtrTy, ptrTy, builder.getI64Type()}); 33 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); 34 return DataDescriptor(descriptor); 35 } 36 37 /// Check whether the type is a valid data descriptor. 38 bool DataDescriptor::isValid(Value descriptor) { 39 if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) { 40 if (type.isIdentified() && type.getName().startswith(getStructName()) && 41 type.getBody().size() == 3 && 42 (type.getBody()[kPtrBasePosInDataDescriptor] 43 .isa<LLVM::LLVMPointerType>() || 44 type.getBody()[kPtrBasePosInDataDescriptor] 45 .isa<LLVM::LLVMStructType>()) && 46 type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() && 47 type.getBody()[kSizePosInDataDescriptor].isInteger(64)) 48 return true; 49 } 50 return false; 51 } 52 53 /// Builds IR inserting the base pointer value into the descriptor. 54 void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc, 55 Value basePtr) { 56 setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr); 57 } 58 59 /// Builds IR inserting the pointer value into the descriptor. 60 void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) { 61 setPtr(builder, loc, kPtrPosInDataDescriptor, ptr); 62 } 63 64 /// Builds IR inserting the size value into the descriptor. 65 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) { 66 setPtr(builder, loc, kSizePosInDataDescriptor, size); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Conversion patterns 71 //===----------------------------------------------------------------------===// 72 73 template <typename Op> 74 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> { 75 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; 76 77 LogicalResult 78 matchAndRewrite(Op op, ArrayRef<Value> operands, 79 ConversionPatternRewriter &builder) const override { 80 Location loc = op.getLoc(); 81 TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 82 83 unsigned numDataOperand = op.getNumDataOperands(); 84 85 // Keep the non data operands without modification. 86 auto nonDataOperands = 87 operands.take_front(operands.size() - numDataOperand); 88 SmallVector<Value> convertedOperands; 89 convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end()); 90 91 // Go over the data operand and legalize them for translation. 92 for (unsigned idx = 0; idx < numDataOperand; ++idx) { 93 Value originalDataOperand = op.getDataOperand(idx); 94 95 // Traverse operands that were converted to MemRefDescriptors. 96 if (auto memRefType = 97 originalDataOperand.getType().dyn_cast<MemRefType>()) { 98 Type structType = converter->convertType(memRefType); 99 Value memRefDescriptor = builder 100 .create<LLVM::DialectCastOp>( 101 loc, structType, originalDataOperand) 102 .getResult(); 103 104 // Calculate the size of the memref and get the pointer to the allocated 105 // buffer. 106 SmallVector<Value> sizes; 107 SmallVector<Value> strides; 108 Value sizeBytes; 109 ConvertToLLVMPattern::getMemRefDescriptorSizes( 110 loc, memRefType, {}, builder, sizes, strides, sizeBytes); 111 MemRefDescriptor descriptor(memRefDescriptor); 112 Value dataPtr = descriptor.alignedPtr(builder, loc); 113 auto ptrType = descriptor.getElementPtrType(); 114 115 auto descr = DataDescriptor::undef(builder, loc, structType, ptrType); 116 descr.setBasePointer(builder, loc, memRefDescriptor); 117 descr.setPointer(builder, loc, dataPtr); 118 descr.setSize(builder, loc, sizeBytes); 119 convertedOperands.push_back(descr); 120 } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) { 121 convertedOperands.push_back(originalDataOperand); 122 } else { 123 // Type not supported. 124 return builder.notifyMatchFailure(op, "unsupported type"); 125 } 126 } 127 128 builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands, 129 op.getOperation()->getAttrs()); 130 131 return success(); 132 } 133 }; 134 135 void mlir::populateOpenACCToLLVMConversionPatterns( 136 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 137 patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter); 138 patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter); 139 patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter); 140 } 141 142 namespace { 143 struct ConvertOpenACCToLLVMPass 144 : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> { 145 void runOnOperation() override; 146 }; 147 } // namespace 148 149 void ConvertOpenACCToLLVMPass::runOnOperation() { 150 auto op = getOperation(); 151 auto *context = op.getContext(); 152 153 // Convert to OpenACC operations with LLVM IR dialect 154 RewritePatternSet patterns(context); 155 LLVMTypeConverter converter(context); 156 populateOpenACCToLLVMConversionPatterns(converter, patterns); 157 158 ConversionTarget target(*context); 159 target.addLegalDialect<LLVM::LLVMDialect>(); 160 161 auto allDataOperandsAreConverted = [](ValueRange operands) { 162 for (Value operand : operands) { 163 if (!DataDescriptor::isValid(operand) && 164 !operand.getType().isa<LLVM::LLVMPointerType>()) 165 return false; 166 } 167 return true; 168 }; 169 170 target.addDynamicallyLegalOp<acc::EnterDataOp>( 171 [allDataOperandsAreConverted](acc::EnterDataOp op) { 172 return allDataOperandsAreConverted(op.copyinOperands()) && 173 allDataOperandsAreConverted(op.createOperands()) && 174 allDataOperandsAreConverted(op.createZeroOperands()) && 175 allDataOperandsAreConverted(op.attachOperands()); 176 }); 177 178 target.addDynamicallyLegalOp<acc::ExitDataOp>( 179 [allDataOperandsAreConverted](acc::ExitDataOp op) { 180 return allDataOperandsAreConverted(op.copyoutOperands()) && 181 allDataOperandsAreConverted(op.deleteOperands()) && 182 allDataOperandsAreConverted(op.detachOperands()); 183 }); 184 185 target.addDynamicallyLegalOp<acc::UpdateOp>( 186 [allDataOperandsAreConverted](acc::UpdateOp op) { 187 return allDataOperandsAreConverted(op.hostOperands()) && 188 allDataOperandsAreConverted(op.deviceOperands()); 189 }); 190 191 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 192 signalPassFailure(); 193 } 194 195 std::unique_ptr<OperationPass<ModuleOp>> 196 mlir::createConvertOpenACCToLLVMPass() { 197 return std::make_unique<ConvertOpenACCToLLVMPass>(); 198 } 199