1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "../PassDetail.h" 10 #include "mlir/Conversion/LLVMCommon/Pattern.h" 11 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/OpenACC/OpenACC.h" 14 #include "mlir/IR/Builders.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // DataDescriptor implementation 20 //===----------------------------------------------------------------------===// 21 22 constexpr StringRef getStructName() { return "openacc_data"; } 23 24 /// Construct a helper for the given descriptor value. 25 DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) { 26 assert(value != nullptr && "value cannot be null"); 27 } 28 29 /// Builds IR creating an `undef` value of the data descriptor. 30 DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc, 31 Type basePtrTy, Type ptrTy) { 32 Type descriptorType = LLVM::LLVMStructType::getNewIdentified( 33 builder.getContext(), getStructName(), 34 {basePtrTy, ptrTy, builder.getI64Type()}); 35 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); 36 return DataDescriptor(descriptor); 37 } 38 39 /// Check whether the type is a valid data descriptor. 40 bool DataDescriptor::isValid(Value descriptor) { 41 if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) { 42 if (type.isIdentified() && type.getName().startswith(getStructName()) && 43 type.getBody().size() == 3 && 44 (type.getBody()[kPtrBasePosInDataDescriptor] 45 .isa<LLVM::LLVMPointerType>() || 46 type.getBody()[kPtrBasePosInDataDescriptor] 47 .isa<LLVM::LLVMStructType>()) && 48 type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() && 49 type.getBody()[kSizePosInDataDescriptor].isInteger(64)) 50 return true; 51 } 52 return false; 53 } 54 55 /// Builds IR inserting the base pointer value into the descriptor. 56 void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc, 57 Value basePtr) { 58 setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr); 59 } 60 61 /// Builds IR inserting the pointer value into the descriptor. 62 void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) { 63 setPtr(builder, loc, kPtrPosInDataDescriptor, ptr); 64 } 65 66 /// Builds IR inserting the size value into the descriptor. 67 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) { 68 setPtr(builder, loc, kSizePosInDataDescriptor, size); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Conversion patterns 73 //===----------------------------------------------------------------------===// 74 75 namespace { 76 77 template <typename Op> 78 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> { 79 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; 80 81 LogicalResult 82 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 83 ConversionPatternRewriter &builder) const override { 84 Location loc = op.getLoc(); 85 TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 86 87 unsigned numDataOperand = op.getNumDataOperands(); 88 89 // Keep the non data operands without modification. 90 auto nonDataOperands = adaptor.getOperands().take_front( 91 adaptor.getOperands().size() - numDataOperand); 92 SmallVector<Value> convertedOperands; 93 convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end()); 94 95 // Go over the data operand and legalize them for translation. 96 for (unsigned idx = 0; idx < numDataOperand; ++idx) { 97 Value originalDataOperand = op.getDataOperand(idx); 98 99 // Traverse operands that were converted to MemRefDescriptors. 100 if (auto memRefType = 101 originalDataOperand.getType().dyn_cast<MemRefType>()) { 102 Type structType = converter->convertType(memRefType); 103 Value memRefDescriptor = builder 104 .create<UnrealizedConversionCastOp>( 105 loc, structType, originalDataOperand) 106 .getResult(0); 107 108 // Calculate the size of the memref and get the pointer to the allocated 109 // buffer. 110 SmallVector<Value> sizes; 111 SmallVector<Value> strides; 112 Value sizeBytes; 113 ConvertToLLVMPattern::getMemRefDescriptorSizes( 114 loc, memRefType, {}, builder, sizes, strides, sizeBytes); 115 MemRefDescriptor descriptor(memRefDescriptor); 116 Value dataPtr = descriptor.alignedPtr(builder, loc); 117 auto ptrType = descriptor.getElementPtrType(); 118 119 auto descr = DataDescriptor::undef(builder, loc, structType, ptrType); 120 descr.setBasePointer(builder, loc, memRefDescriptor); 121 descr.setPointer(builder, loc, dataPtr); 122 descr.setSize(builder, loc, sizeBytes); 123 convertedOperands.push_back(descr); 124 } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) { 125 convertedOperands.push_back(originalDataOperand); 126 } else { 127 // Type not supported. 128 return builder.notifyMatchFailure(op, "unsupported type"); 129 } 130 } 131 132 builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands, 133 op.getOperation()->getAttrs()); 134 135 return success(); 136 } 137 }; 138 } // namespace 139 140 void mlir::populateOpenACCToLLVMConversionPatterns( 141 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 142 patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter); 143 patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter); 144 patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter); 145 patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter); 146 patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter); 147 } 148 149 namespace { 150 struct ConvertOpenACCToLLVMPass 151 : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> { 152 void runOnOperation() override; 153 }; 154 } // namespace 155 156 void ConvertOpenACCToLLVMPass::runOnOperation() { 157 auto op = getOperation(); 158 auto *context = op.getContext(); 159 160 // Convert to OpenACC operations with LLVM IR dialect 161 RewritePatternSet patterns(context); 162 LLVMTypeConverter converter(context); 163 populateOpenACCToLLVMConversionPatterns(converter, patterns); 164 165 ConversionTarget target(*context); 166 target.addLegalDialect<LLVM::LLVMDialect>(); 167 target.addLegalOp<UnrealizedConversionCastOp>(); 168 169 auto allDataOperandsAreConverted = [](ValueRange operands) { 170 for (Value operand : operands) { 171 if (!DataDescriptor::isValid(operand) && 172 !operand.getType().isa<LLVM::LLVMPointerType>()) 173 return false; 174 } 175 return true; 176 }; 177 178 target.addDynamicallyLegalOp<acc::DataOp>( 179 [allDataOperandsAreConverted](acc::DataOp op) { 180 return allDataOperandsAreConverted(op.copyOperands()) && 181 allDataOperandsAreConverted(op.copyinOperands()) && 182 allDataOperandsAreConverted(op.copyinReadonlyOperands()) && 183 allDataOperandsAreConverted(op.copyoutOperands()) && 184 allDataOperandsAreConverted(op.copyoutZeroOperands()) && 185 allDataOperandsAreConverted(op.createOperands()) && 186 allDataOperandsAreConverted(op.createZeroOperands()) && 187 allDataOperandsAreConverted(op.noCreateOperands()) && 188 allDataOperandsAreConverted(op.presentOperands()) && 189 allDataOperandsAreConverted(op.deviceptrOperands()) && 190 allDataOperandsAreConverted(op.attachOperands()); 191 }); 192 193 target.addDynamicallyLegalOp<acc::EnterDataOp>( 194 [allDataOperandsAreConverted](acc::EnterDataOp op) { 195 return allDataOperandsAreConverted(op.copyinOperands()) && 196 allDataOperandsAreConverted(op.createOperands()) && 197 allDataOperandsAreConverted(op.createZeroOperands()) && 198 allDataOperandsAreConverted(op.attachOperands()); 199 }); 200 201 target.addDynamicallyLegalOp<acc::ExitDataOp>( 202 [allDataOperandsAreConverted](acc::ExitDataOp op) { 203 return allDataOperandsAreConverted(op.copyoutOperands()) && 204 allDataOperandsAreConverted(op.deleteOperands()) && 205 allDataOperandsAreConverted(op.detachOperands()); 206 }); 207 208 target.addDynamicallyLegalOp<acc::ParallelOp>( 209 [allDataOperandsAreConverted](acc::ParallelOp op) { 210 return allDataOperandsAreConverted(op.reductionOperands()) && 211 allDataOperandsAreConverted(op.copyOperands()) && 212 allDataOperandsAreConverted(op.copyinOperands()) && 213 allDataOperandsAreConverted(op.copyinReadonlyOperands()) && 214 allDataOperandsAreConverted(op.copyoutOperands()) && 215 allDataOperandsAreConverted(op.copyoutZeroOperands()) && 216 allDataOperandsAreConverted(op.createOperands()) && 217 allDataOperandsAreConverted(op.createZeroOperands()) && 218 allDataOperandsAreConverted(op.noCreateOperands()) && 219 allDataOperandsAreConverted(op.presentOperands()) && 220 allDataOperandsAreConverted(op.devicePtrOperands()) && 221 allDataOperandsAreConverted(op.attachOperands()) && 222 allDataOperandsAreConverted(op.gangPrivateOperands()) && 223 allDataOperandsAreConverted(op.gangFirstPrivateOperands()); 224 }); 225 226 target.addDynamicallyLegalOp<acc::UpdateOp>( 227 [allDataOperandsAreConverted](acc::UpdateOp op) { 228 return allDataOperandsAreConverted(op.hostOperands()) && 229 allDataOperandsAreConverted(op.deviceOperands()); 230 }); 231 232 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 233 signalPassFailure(); 234 } 235 236 std::unique_ptr<OperationPass<ModuleOp>> 237 mlir::createConvertOpenACCToLLVMPass() { 238 return std::make_unique<ConvertOpenACCToLLVMPass>(); 239 } 240