1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "../PassDetail.h" 10 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/OpenACC/OpenACC.h" 13 14 using namespace mlir; 15 16 //===----------------------------------------------------------------------===// 17 // DataDescriptor implementation 18 //===----------------------------------------------------------------------===// 19 20 constexpr StringRef getStructName() { return "openacc_data"; } 21 22 /// Construct a helper for the given descriptor value. 23 DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) { 24 assert(value != nullptr && "value cannot be null"); 25 } 26 27 /// Builds IR creating an `undef` value of the data descriptor. 28 DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc, 29 Type basePtrTy, Type ptrTy) { 30 Type descriptorType = LLVM::LLVMStructType::getNewIdentified( 31 builder.getContext(), getStructName(), 32 {basePtrTy, ptrTy, builder.getI64Type()}); 33 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); 34 return DataDescriptor(descriptor); 35 } 36 37 /// Check whether the type is a valid data descriptor. 38 bool DataDescriptor::isValid(Value descriptor) { 39 if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) { 40 if (type.isIdentified() && type.getName().startswith(getStructName()) && 41 type.getBody().size() == 3 && 42 (type.getBody()[kPtrBasePosInDataDescriptor] 43 .isa<LLVM::LLVMPointerType>() || 44 type.getBody()[kPtrBasePosInDataDescriptor] 45 .isa<LLVM::LLVMStructType>()) && 46 type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() && 47 type.getBody()[kSizePosInDataDescriptor].isInteger(64)) 48 return true; 49 } 50 return false; 51 } 52 53 /// Builds IR inserting the base pointer value into the descriptor. 54 void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc, 55 Value basePtr) { 56 setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr); 57 } 58 59 /// Builds IR inserting the pointer value into the descriptor. 60 void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) { 61 setPtr(builder, loc, kPtrPosInDataDescriptor, ptr); 62 } 63 64 /// Builds IR inserting the size value into the descriptor. 65 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) { 66 setPtr(builder, loc, kSizePosInDataDescriptor, size); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Conversion patterns 71 //===----------------------------------------------------------------------===// 72 73 namespace { 74 75 template <typename Op> 76 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> { 77 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; 78 79 LogicalResult 80 matchAndRewrite(Op op, ArrayRef<Value> operands, 81 ConversionPatternRewriter &builder) const override { 82 Location loc = op.getLoc(); 83 TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 84 85 unsigned numDataOperand = op.getNumDataOperands(); 86 87 // Keep the non data operands without modification. 88 auto nonDataOperands = 89 operands.take_front(operands.size() - numDataOperand); 90 SmallVector<Value> convertedOperands; 91 convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end()); 92 93 // Go over the data operand and legalize them for translation. 94 for (unsigned idx = 0; idx < numDataOperand; ++idx) { 95 Value originalDataOperand = op.getDataOperand(idx); 96 97 // Traverse operands that were converted to MemRefDescriptors. 98 if (auto memRefType = 99 originalDataOperand.getType().dyn_cast<MemRefType>()) { 100 Type structType = converter->convertType(memRefType); 101 Value memRefDescriptor = builder 102 .create<LLVM::DialectCastOp>( 103 loc, structType, originalDataOperand) 104 .getResult(); 105 106 // Calculate the size of the memref and get the pointer to the allocated 107 // buffer. 108 SmallVector<Value> sizes; 109 SmallVector<Value> strides; 110 Value sizeBytes; 111 ConvertToLLVMPattern::getMemRefDescriptorSizes( 112 loc, memRefType, {}, builder, sizes, strides, sizeBytes); 113 MemRefDescriptor descriptor(memRefDescriptor); 114 Value dataPtr = descriptor.alignedPtr(builder, loc); 115 auto ptrType = descriptor.getElementPtrType(); 116 117 auto descr = DataDescriptor::undef(builder, loc, structType, ptrType); 118 descr.setBasePointer(builder, loc, memRefDescriptor); 119 descr.setPointer(builder, loc, dataPtr); 120 descr.setSize(builder, loc, sizeBytes); 121 convertedOperands.push_back(descr); 122 } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) { 123 convertedOperands.push_back(originalDataOperand); 124 } else { 125 // Type not supported. 126 return builder.notifyMatchFailure(op, "unsupported type"); 127 } 128 } 129 130 builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands, 131 op.getOperation()->getAttrs()); 132 133 return success(); 134 } 135 }; 136 } // namespace 137 138 void mlir::populateOpenACCToLLVMConversionPatterns( 139 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 140 patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter); 141 patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter); 142 patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter); 143 patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter); 144 patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter); 145 } 146 147 namespace { 148 struct ConvertOpenACCToLLVMPass 149 : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> { 150 void runOnOperation() override; 151 }; 152 } // namespace 153 154 void ConvertOpenACCToLLVMPass::runOnOperation() { 155 auto op = getOperation(); 156 auto *context = op.getContext(); 157 158 // Convert to OpenACC operations with LLVM IR dialect 159 RewritePatternSet patterns(context); 160 LLVMTypeConverter converter(context); 161 populateOpenACCToLLVMConversionPatterns(converter, patterns); 162 163 ConversionTarget target(*context); 164 target.addLegalDialect<LLVM::LLVMDialect>(); 165 166 auto allDataOperandsAreConverted = [](ValueRange operands) { 167 for (Value operand : operands) { 168 if (!DataDescriptor::isValid(operand) && 169 !operand.getType().isa<LLVM::LLVMPointerType>()) 170 return false; 171 } 172 return true; 173 }; 174 175 target.addDynamicallyLegalOp<acc::DataOp>( 176 [allDataOperandsAreConverted](acc::DataOp op) { 177 return allDataOperandsAreConverted(op.copyOperands()) && 178 allDataOperandsAreConverted(op.copyinOperands()) && 179 allDataOperandsAreConverted(op.copyinReadonlyOperands()) && 180 allDataOperandsAreConverted(op.copyoutOperands()) && 181 allDataOperandsAreConverted(op.copyoutZeroOperands()) && 182 allDataOperandsAreConverted(op.createOperands()) && 183 allDataOperandsAreConverted(op.createZeroOperands()) && 184 allDataOperandsAreConverted(op.noCreateOperands()) && 185 allDataOperandsAreConverted(op.presentOperands()) && 186 allDataOperandsAreConverted(op.deviceptrOperands()) && 187 allDataOperandsAreConverted(op.attachOperands()); 188 }); 189 190 target.addDynamicallyLegalOp<acc::EnterDataOp>( 191 [allDataOperandsAreConverted](acc::EnterDataOp op) { 192 return allDataOperandsAreConverted(op.copyinOperands()) && 193 allDataOperandsAreConverted(op.createOperands()) && 194 allDataOperandsAreConverted(op.createZeroOperands()) && 195 allDataOperandsAreConverted(op.attachOperands()); 196 }); 197 198 target.addDynamicallyLegalOp<acc::ExitDataOp>( 199 [allDataOperandsAreConverted](acc::ExitDataOp op) { 200 return allDataOperandsAreConverted(op.copyoutOperands()) && 201 allDataOperandsAreConverted(op.deleteOperands()) && 202 allDataOperandsAreConverted(op.detachOperands()); 203 }); 204 205 target.addDynamicallyLegalOp<acc::ParallelOp>( 206 [allDataOperandsAreConverted](acc::ParallelOp op) { 207 return allDataOperandsAreConverted(op.reductionOperands()) && 208 allDataOperandsAreConverted(op.copyOperands()) && 209 allDataOperandsAreConverted(op.copyinOperands()) && 210 allDataOperandsAreConverted(op.copyinReadonlyOperands()) && 211 allDataOperandsAreConverted(op.copyoutOperands()) && 212 allDataOperandsAreConverted(op.copyoutZeroOperands()) && 213 allDataOperandsAreConverted(op.createOperands()) && 214 allDataOperandsAreConverted(op.createZeroOperands()) && 215 allDataOperandsAreConverted(op.noCreateOperands()) && 216 allDataOperandsAreConverted(op.presentOperands()) && 217 allDataOperandsAreConverted(op.devicePtrOperands()) && 218 allDataOperandsAreConverted(op.attachOperands()) && 219 allDataOperandsAreConverted(op.gangPrivateOperands()) && 220 allDataOperandsAreConverted(op.gangFirstPrivateOperands()); 221 }); 222 223 target.addDynamicallyLegalOp<acc::UpdateOp>( 224 [allDataOperandsAreConverted](acc::UpdateOp op) { 225 return allDataOperandsAreConverted(op.hostOperands()) && 226 allDataOperandsAreConverted(op.deviceOperands()); 227 }); 228 229 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 230 signalPassFailure(); 231 } 232 233 std::unique_ptr<OperationPass<ModuleOp>> 234 mlir::createConvertOpenACCToLLVMPass() { 235 return std::make_unique<ConvertOpenACCToLLVMPass>(); 236 } 237