1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 13 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 17 #include "mlir/Conversion/LLVMCommon/Pattern.h" 18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 21 22 using namespace mlir; 23 24 namespace { 25 /// A pattern that converts the region arguments in a single-region OpenMP 26 /// operation to the LLVM dialect. The body of the region is not modified and is 27 /// expected to either be processed by the conversion infrastructure or already 28 /// contain ops compatible with LLVM dialect types. 29 template <typename OpType> 30 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> { 31 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; 32 33 LogicalResult 34 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, 35 ConversionPatternRewriter &rewriter) const override { 36 auto newOp = rewriter.create<OpType>( 37 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); 38 rewriter.inlineRegionBefore(curOp.region(), newOp.region(), 39 newOp.region().end()); 40 if (failed(rewriter.convertRegionTypes(&newOp.region(), 41 *this->getTypeConverter()))) 42 return failure(); 43 44 rewriter.eraseOp(curOp); 45 return success(); 46 } 47 }; 48 49 template <typename T> 50 struct RegionLessOpWithVarOperandsConversion 51 : public ConvertOpToLLVMPattern<T> { 52 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; 53 LogicalResult 54 matchAndRewrite(T curOp, typename T::Adaptor adaptor, 55 ConversionPatternRewriter &rewriter) const override { 56 TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 57 SmallVector<Type> resTypes; 58 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) 59 return failure(); 60 SmallVector<Value> convertedOperands; 61 assert(curOp.getNumVariableOperands() == 62 curOp.getOperation()->getNumOperands() && 63 "unexpected non-variable operands"); 64 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { 65 Value originalVariableOperand = curOp.getVariableOperand(idx); 66 if (!originalVariableOperand) 67 return failure(); 68 if (originalVariableOperand.getType().isa<MemRefType>()) { 69 // TODO: Support memref type in variable operands 70 return rewriter.notifyMatchFailure(curOp, 71 "memref is not supported yet"); 72 } 73 convertedOperands.emplace_back(adaptor.getOperands()[idx]); 74 } 75 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands, 76 curOp->getAttrs()); 77 return success(); 78 } 79 }; 80 81 struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> { 82 using ConvertOpToLLVMPattern<omp::ReductionOp>::ConvertOpToLLVMPattern; 83 LogicalResult 84 matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor, 85 ConversionPatternRewriter &rewriter) const override { 86 if (curOp.accumulator().getType().isa<MemRefType>()) { 87 // TODO: Support memref type in variable operands 88 return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); 89 } 90 rewriter.replaceOpWithNewOp<omp::ReductionOp>( 91 curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs()); 92 return success(); 93 } 94 }; 95 } // namespace 96 97 void mlir::configureOpenMPToLLVMConversionLegality( 98 ConversionTarget &target, LLVMTypeConverter &typeConverter) { 99 target.addDynamicallyLegalOp<mlir::omp::CriticalOp, mlir::omp::ParallelOp, 100 mlir::omp::WsLoopOp, mlir::omp::MasterOp, 101 mlir::omp::SectionsOp, mlir::omp::SingleOp>( 102 [&](Operation *op) { 103 return typeConverter.isLegal(&op->getRegion(0)) && 104 typeConverter.isLegal(op->getOperandTypes()) && 105 typeConverter.isLegal(op->getResultTypes()); 106 }); 107 target 108 .addDynamicallyLegalOp<mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, 109 mlir::omp::FlushOp, mlir::omp::ThreadprivateOp>( 110 [&](Operation *op) { 111 return typeConverter.isLegal(op->getOperandTypes()) && 112 typeConverter.isLegal(op->getResultTypes()); 113 }); 114 target.addDynamicallyLegalOp<mlir::omp::ReductionOp>([&](Operation *op) { 115 return typeConverter.isLegal(op->getOperandTypes()); 116 }); 117 } 118 119 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, 120 RewritePatternSet &patterns) { 121 patterns.add< 122 ReductionOpConversion, RegionOpConversion<omp::CriticalOp>, 123 RegionOpConversion<omp::MasterOp>, ReductionOpConversion, 124 RegionOpConversion<omp::MasterOp>, RegionOpConversion<omp::ParallelOp>, 125 RegionOpConversion<omp::WsLoopOp>, RegionOpConversion<omp::SectionsOp>, 126 RegionOpConversion<omp::SingleOp>, 127 RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>, 128 RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>, 129 RegionLessOpWithVarOperandsConversion<omp::FlushOp>, 130 RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>>(converter); 131 } 132 133 namespace { 134 struct ConvertOpenMPToLLVMPass 135 : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> { 136 void runOnOperation() override; 137 }; 138 } // namespace 139 140 void ConvertOpenMPToLLVMPass::runOnOperation() { 141 auto module = getOperation(); 142 143 // Convert to OpenMP operations with LLVM IR dialect 144 RewritePatternSet patterns(&getContext()); 145 LLVMTypeConverter converter(&getContext()); 146 arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); 147 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 148 populateMemRefToLLVMConversionPatterns(converter, patterns); 149 populateFuncToLLVMConversionPatterns(converter, patterns); 150 populateOpenMPToLLVMConversionPatterns(converter, patterns); 151 152 LLVMConversionTarget target(getContext()); 153 target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp, 154 omp::BarrierOp, omp::TaskwaitOp>(); 155 configureOpenMPToLLVMConversionLegality(target, converter); 156 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 157 signalPassFailure(); 158 } 159 160 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() { 161 return std::make_unique<ConvertOpenMPToLLVMPass>(); 162 } 163