1119545f4SAlex Zinenko //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===// 2119545f4SAlex Zinenko // 3119545f4SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4119545f4SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5119545f4SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6119545f4SAlex Zinenko // 7119545f4SAlex Zinenko //===----------------------------------------------------------------------===// 8119545f4SAlex Zinenko // 9119545f4SAlex Zinenko // This file implements a pass to convert scf.parallel operations into OpenMP 10119545f4SAlex Zinenko // parallel loops. 11119545f4SAlex Zinenko // 12119545f4SAlex Zinenko //===----------------------------------------------------------------------===// 13119545f4SAlex Zinenko 14119545f4SAlex Zinenko #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" 15119545f4SAlex Zinenko #include "../PassDetail.h" 162a876a71SDiego Caballero #include "mlir/Analysis/LoopAnalysis.h" 17*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 181ce752b7SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19119545f4SAlex Zinenko #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 20119545f4SAlex Zinenko #include "mlir/Dialect/SCF/SCF.h" 211ce752b7SAlex Zinenko #include "mlir/Dialect/StandardOps/IR/Ops.h" 221ce752b7SAlex Zinenko #include "mlir/IR/ImplicitLocOpBuilder.h" 231ce752b7SAlex Zinenko #include "mlir/IR/SymbolTable.h" 24119545f4SAlex Zinenko #include "mlir/Transforms/DialectConversion.h" 25119545f4SAlex Zinenko 26119545f4SAlex Zinenko using namespace mlir; 27119545f4SAlex Zinenko 281ce752b7SAlex Zinenko /// Matches a block containing a "simple" reduction. The expected shape of the 291ce752b7SAlex Zinenko /// block is as follows. 301ce752b7SAlex Zinenko /// 311ce752b7SAlex Zinenko /// ^bb(%arg0, %arg1): 321ce752b7SAlex Zinenko /// %0 = OpTy(%arg0, %arg1) 331ce752b7SAlex Zinenko /// scf.reduce.return %0 341ce752b7SAlex Zinenko template <typename... OpTy> 351ce752b7SAlex Zinenko static bool matchSimpleReduction(Block &block) { 361ce752b7SAlex Zinenko if (block.empty() || llvm::hasSingleElement(block) || 371ce752b7SAlex Zinenko std::next(block.begin(), 2) != block.end()) 381ce752b7SAlex Zinenko return false; 392a876a71SDiego Caballero 402a876a71SDiego Caballero if (block.getNumArguments() != 2) 412a876a71SDiego Caballero return false; 422a876a71SDiego Caballero 432a876a71SDiego Caballero SmallVector<Operation *, 4> combinerOps; 442a876a71SDiego Caballero Value reducedVal = matchReduction({block.getArguments()[1]}, 452a876a71SDiego Caballero /*redPos=*/0, combinerOps); 462a876a71SDiego Caballero 472a876a71SDiego Caballero if (!reducedVal || !reducedVal.isa<BlockArgument>() || 482a876a71SDiego Caballero combinerOps.size() != 1) 492a876a71SDiego Caballero return false; 502a876a71SDiego Caballero 512a876a71SDiego Caballero return isa<OpTy...>(combinerOps[0]) && 521ce752b7SAlex Zinenko isa<scf::ReduceReturnOp>(block.back()) && 532a876a71SDiego Caballero block.front().getOperands() == block.getArguments(); 541ce752b7SAlex Zinenko } 551ce752b7SAlex Zinenko 561ce752b7SAlex Zinenko /// Matches a block containing a select-based min/max reduction. The types of 571ce752b7SAlex Zinenko /// select and compare operations are provided as template arguments. The 581ce752b7SAlex Zinenko /// comparison predicates suitable for min and max are provided as function 591ce752b7SAlex Zinenko /// arguments. If a reduction is matched, `ifMin` will be set if the reduction 601ce752b7SAlex Zinenko /// compute the minimum and unset if it computes the maximum, otherwise it 611ce752b7SAlex Zinenko /// remains unmodified. The expected shape of the block is as follows. 621ce752b7SAlex Zinenko /// 631ce752b7SAlex Zinenko /// ^bb(%arg0, %arg1): 641ce752b7SAlex Zinenko /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1) 651ce752b7SAlex Zinenko /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here. 661ce752b7SAlex Zinenko /// scf.reduce.return %1 671ce752b7SAlex Zinenko template < 681ce752b7SAlex Zinenko typename CompareOpTy, typename SelectOpTy, 691ce752b7SAlex Zinenko typename Predicate = decltype(std::declval<CompareOpTy>().predicate())> 701ce752b7SAlex Zinenko static bool 711ce752b7SAlex Zinenko matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates, 721ce752b7SAlex Zinenko ArrayRef<Predicate> greaterThanPredicates, bool &isMin) { 731ce752b7SAlex Zinenko static_assert(llvm::is_one_of<SelectOpTy, SelectOp, LLVM::SelectOp>::value, 741ce752b7SAlex Zinenko "only std and llvm select ops are supported"); 751ce752b7SAlex Zinenko 761ce752b7SAlex Zinenko // Expect exactly three operations in the block. 771ce752b7SAlex Zinenko if (block.empty() || llvm::hasSingleElement(block) || 781ce752b7SAlex Zinenko std::next(block.begin(), 2) == block.end() || 791ce752b7SAlex Zinenko std::next(block.begin(), 3) != block.end()) 801ce752b7SAlex Zinenko return false; 811ce752b7SAlex Zinenko 821ce752b7SAlex Zinenko // Check op kinds. 831ce752b7SAlex Zinenko auto compare = dyn_cast<CompareOpTy>(block.front()); 841ce752b7SAlex Zinenko auto select = dyn_cast<SelectOpTy>(block.front().getNextNode()); 851ce752b7SAlex Zinenko auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back()); 861ce752b7SAlex Zinenko if (!compare || !select || !terminator) 871ce752b7SAlex Zinenko return false; 881ce752b7SAlex Zinenko 891ce752b7SAlex Zinenko // Block arguments must be compared. 901ce752b7SAlex Zinenko if (compare->getOperands() != block.getArguments()) 911ce752b7SAlex Zinenko return false; 921ce752b7SAlex Zinenko 931ce752b7SAlex Zinenko // Detect whether the comparison is less-than or greater-than, otherwise bail. 941ce752b7SAlex Zinenko bool isLess; 951ce752b7SAlex Zinenko if (llvm::find(lessThanPredicates, compare.predicate()) != 961ce752b7SAlex Zinenko lessThanPredicates.end()) { 971ce752b7SAlex Zinenko isLess = true; 981ce752b7SAlex Zinenko } else if (llvm::find(greaterThanPredicates, compare.predicate()) != 991ce752b7SAlex Zinenko greaterThanPredicates.end()) { 1001ce752b7SAlex Zinenko isLess = false; 1011ce752b7SAlex Zinenko } else { 1021ce752b7SAlex Zinenko return false; 1031ce752b7SAlex Zinenko } 1041ce752b7SAlex Zinenko 1051ce752b7SAlex Zinenko if (select.condition() != compare.getResult()) 1061ce752b7SAlex Zinenko return false; 1071ce752b7SAlex Zinenko 1081ce752b7SAlex Zinenko // Detect if the operands are swapped between cmpf and select. Match the 1091ce752b7SAlex Zinenko // comparison type with the requested type or with the opposite of the 1101ce752b7SAlex Zinenko // requested type if the operands are swapped. Use generic accessors because 1111ce752b7SAlex Zinenko // std and LLVM versions of select have different operand names but identical 1121ce752b7SAlex Zinenko // positions. 1131ce752b7SAlex Zinenko constexpr unsigned kTrueValue = 1; 1141ce752b7SAlex Zinenko constexpr unsigned kFalseValue = 2; 1151ce752b7SAlex Zinenko bool sameOperands = select.getOperand(kTrueValue) == compare.lhs() && 1161ce752b7SAlex Zinenko select.getOperand(kFalseValue) == compare.rhs(); 1171ce752b7SAlex Zinenko bool swappedOperands = select.getOperand(kTrueValue) == compare.rhs() && 1181ce752b7SAlex Zinenko select.getOperand(kFalseValue) == compare.lhs(); 1191ce752b7SAlex Zinenko if (!sameOperands && !swappedOperands) 1201ce752b7SAlex Zinenko return false; 1211ce752b7SAlex Zinenko 1221ce752b7SAlex Zinenko if (select.getResult() != terminator.result()) 1231ce752b7SAlex Zinenko return false; 1241ce752b7SAlex Zinenko 1251ce752b7SAlex Zinenko // The reduction is a min if it uses less-than predicates with same operands 1261ce752b7SAlex Zinenko // or greather-than predicates with swapped operands. Similarly for max. 1271ce752b7SAlex Zinenko isMin = (isLess && sameOperands) || (!isLess && swappedOperands); 1281ce752b7SAlex Zinenko return isMin || (isLess & swappedOperands) || (!isLess && sameOperands); 1291ce752b7SAlex Zinenko } 1301ce752b7SAlex Zinenko 1311ce752b7SAlex Zinenko /// Returns the float semantics for the given float type. 1321ce752b7SAlex Zinenko static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { 1331ce752b7SAlex Zinenko if (type.isF16()) 1341ce752b7SAlex Zinenko return llvm::APFloat::IEEEhalf(); 1351ce752b7SAlex Zinenko if (type.isF32()) 1361ce752b7SAlex Zinenko return llvm::APFloat::IEEEsingle(); 1371ce752b7SAlex Zinenko if (type.isF64()) 1381ce752b7SAlex Zinenko return llvm::APFloat::IEEEdouble(); 1391ce752b7SAlex Zinenko if (type.isF128()) 1401ce752b7SAlex Zinenko return llvm::APFloat::IEEEquad(); 1411ce752b7SAlex Zinenko if (type.isBF16()) 1421ce752b7SAlex Zinenko return llvm::APFloat::BFloat(); 1431ce752b7SAlex Zinenko if (type.isF80()) 1441ce752b7SAlex Zinenko return llvm::APFloat::x87DoubleExtended(); 1451ce752b7SAlex Zinenko llvm_unreachable("unknown float type"); 1461ce752b7SAlex Zinenko } 1471ce752b7SAlex Zinenko 1481ce752b7SAlex Zinenko /// Returns an attribute with the minimum (if `min` is set) or the maximum value 1491ce752b7SAlex Zinenko /// (otherwise) for the given float type. 1501ce752b7SAlex Zinenko static Attribute minMaxValueForFloat(Type type, bool min) { 1511ce752b7SAlex Zinenko auto fltType = type.cast<FloatType>(); 1521ce752b7SAlex Zinenko return FloatAttr::get( 1531ce752b7SAlex Zinenko type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); 1541ce752b7SAlex Zinenko } 1551ce752b7SAlex Zinenko 1561ce752b7SAlex Zinenko /// Returns an attribute with the signed integer minimum (if `min` is set) or 1571ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its 1581ce752b7SAlex Zinenko /// signedness semantics (only the width is considered). 1591ce752b7SAlex Zinenko static Attribute minMaxValueForSignedInt(Type type, bool min) { 1601ce752b7SAlex Zinenko auto intType = type.cast<IntegerType>(); 1611ce752b7SAlex Zinenko unsigned bitwidth = intType.getWidth(); 1621ce752b7SAlex Zinenko return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) 1631ce752b7SAlex Zinenko : llvm::APInt::getSignedMaxValue(bitwidth)); 1641ce752b7SAlex Zinenko } 1651ce752b7SAlex Zinenko 1661ce752b7SAlex Zinenko /// Returns an attribute with the unsigned integer minimum (if `min` is set) or 1671ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its 1681ce752b7SAlex Zinenko /// signedness semantics (only the width is considered). 1691ce752b7SAlex Zinenko static Attribute minMaxValueForUnsignedInt(Type type, bool min) { 1701ce752b7SAlex Zinenko auto intType = type.cast<IntegerType>(); 1711ce752b7SAlex Zinenko unsigned bitwidth = intType.getWidth(); 1721ce752b7SAlex Zinenko return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth) 1731ce752b7SAlex Zinenko : llvm::APInt::getAllOnesValue(bitwidth)); 1741ce752b7SAlex Zinenko } 1751ce752b7SAlex Zinenko 1761ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration and inserts it into the provided 1771ce752b7SAlex Zinenko /// symbol table. The declaration has a constant initializer with the neutral 1781ce752b7SAlex Zinenko /// value `initValue`, and the reduction combiner carried over from `reduce`. 1791ce752b7SAlex Zinenko static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, 1801ce752b7SAlex Zinenko SymbolTable &symbolTable, 1811ce752b7SAlex Zinenko scf::ReduceOp reduce, 1821ce752b7SAlex Zinenko Attribute initValue) { 1831ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 1841ce752b7SAlex Zinenko auto decl = builder.create<omp::ReductionDeclareOp>( 1851ce752b7SAlex Zinenko reduce.getLoc(), "__scf_reduction", reduce.operand().getType()); 1861ce752b7SAlex Zinenko symbolTable.insert(decl); 1871ce752b7SAlex Zinenko 1881ce752b7SAlex Zinenko Type type = reduce.operand().getType(); 1891ce752b7SAlex Zinenko builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), 1901ce752b7SAlex Zinenko {type}); 1911ce752b7SAlex Zinenko builder.setInsertionPointToEnd(&decl.initializerRegion().back()); 1921ce752b7SAlex Zinenko Value init = 1931ce752b7SAlex Zinenko builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue); 1941ce752b7SAlex Zinenko builder.create<omp::YieldOp>(reduce.getLoc(), init); 1951ce752b7SAlex Zinenko 1961ce752b7SAlex Zinenko Operation *terminator = &reduce.getRegion().front().back(); 1971ce752b7SAlex Zinenko assert(isa<scf::ReduceReturnOp>(terminator) && 1981ce752b7SAlex Zinenko "expected reduce op to be terminated by redure return"); 1991ce752b7SAlex Zinenko builder.setInsertionPoint(terminator); 2001ce752b7SAlex Zinenko builder.replaceOpWithNewOp<omp::YieldOp>(terminator, 2011ce752b7SAlex Zinenko terminator->getOperands()); 2021ce752b7SAlex Zinenko builder.inlineRegionBefore(reduce.getRegion(), decl.reductionRegion(), 2031ce752b7SAlex Zinenko decl.reductionRegion().end()); 2041ce752b7SAlex Zinenko return decl; 2051ce752b7SAlex Zinenko } 2061ce752b7SAlex Zinenko 2071ce752b7SAlex Zinenko /// Adds an atomic reduction combiner to the given OpenMP reduction declaration 2081ce752b7SAlex Zinenko /// using llvm.atomicrmw of the given kind. 2091ce752b7SAlex Zinenko static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, 2101ce752b7SAlex Zinenko LLVM::AtomicBinOp atomicKind, 2111ce752b7SAlex Zinenko omp::ReductionDeclareOp decl, 2121ce752b7SAlex Zinenko scf::ReduceOp reduce) { 2131ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 2141ce752b7SAlex Zinenko Type type = reduce.operand().getType(); 2151ce752b7SAlex Zinenko Type ptrType = LLVM::LLVMPointerType::get(type); 2161ce752b7SAlex Zinenko builder.createBlock(&decl.atomicReductionRegion(), 2171ce752b7SAlex Zinenko decl.atomicReductionRegion().end(), {ptrType, ptrType}); 2181ce752b7SAlex Zinenko Block *atomicBlock = &decl.atomicReductionRegion().back(); 2191ce752b7SAlex Zinenko builder.setInsertionPointToEnd(atomicBlock); 2201ce752b7SAlex Zinenko Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), 2211ce752b7SAlex Zinenko atomicBlock->getArgument(1)); 2221ce752b7SAlex Zinenko builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), type, atomicKind, 2231ce752b7SAlex Zinenko atomicBlock->getArgument(0), loaded, 2241ce752b7SAlex Zinenko LLVM::AtomicOrdering::monotonic); 2251ce752b7SAlex Zinenko builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>()); 2261ce752b7SAlex Zinenko return decl; 2271ce752b7SAlex Zinenko } 2281ce752b7SAlex Zinenko 2291ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration that corresponds to the given SCF 2301ce752b7SAlex Zinenko /// reduction and returns it. Recognizes common reductions in order to identify 2311ce752b7SAlex Zinenko /// the neutral value, necessary for the OpenMP declaration. If the reduction 2321ce752b7SAlex Zinenko /// cannot be recognized, returns null. 2331ce752b7SAlex Zinenko static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, 2341ce752b7SAlex Zinenko scf::ReduceOp reduce) { 2351ce752b7SAlex Zinenko Operation *container = SymbolTable::getNearestSymbolTable(reduce); 2361ce752b7SAlex Zinenko SymbolTable symbolTable(container); 2371ce752b7SAlex Zinenko 2381ce752b7SAlex Zinenko // Insert reduction declarations in the symbol-table ancestor before the 2391ce752b7SAlex Zinenko // ancestor of the current insertion point. 2401ce752b7SAlex Zinenko Operation *insertionPoint = reduce; 2411ce752b7SAlex Zinenko while (insertionPoint->getParentOp() != container) 2421ce752b7SAlex Zinenko insertionPoint = insertionPoint->getParentOp(); 2431ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 2441ce752b7SAlex Zinenko builder.setInsertionPoint(insertionPoint); 2451ce752b7SAlex Zinenko 2461ce752b7SAlex Zinenko assert(llvm::hasSingleElement(reduce.getRegion()) && 2471ce752b7SAlex Zinenko "expected reduction region to have a single element"); 2481ce752b7SAlex Zinenko 2491ce752b7SAlex Zinenko // Match simple binary reductions that can be expressed with atomicrmw. 2501ce752b7SAlex Zinenko Type type = reduce.operand().getType(); 2511ce752b7SAlex Zinenko Block &reduction = reduce.getRegion().front(); 252*a54f4eaeSMogball if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) { 2531ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2541ce752b7SAlex Zinenko builder.getFloatAttr(type, 0.0)); 2551ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); 2561ce752b7SAlex Zinenko } 257*a54f4eaeSMogball if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) { 2581ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2591ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 2601ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); 2611ce752b7SAlex Zinenko } 262*a54f4eaeSMogball if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) { 2631ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2641ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 2651ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); 2661ce752b7SAlex Zinenko } 267*a54f4eaeSMogball if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) { 2681ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2691ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 2701ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); 2711ce752b7SAlex Zinenko } 272*a54f4eaeSMogball if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) { 2731ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl( 2741ce752b7SAlex Zinenko builder, symbolTable, reduce, 2751ce752b7SAlex Zinenko builder.getIntegerAttr( 2761ce752b7SAlex Zinenko type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth()))); 2771ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce); 2781ce752b7SAlex Zinenko } 2791ce752b7SAlex Zinenko 2801ce752b7SAlex Zinenko // Match simple binary reductions that cannot be expressed with atomicrmw. 2811ce752b7SAlex Zinenko // TODO: add atomic region using cmpxchg (which needs atomic load to be 2821ce752b7SAlex Zinenko // available as an op). 283*a54f4eaeSMogball if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) { 2841ce752b7SAlex Zinenko return createDecl(builder, symbolTable, reduce, 2851ce752b7SAlex Zinenko builder.getFloatAttr(type, 1.0)); 2861ce752b7SAlex Zinenko } 2871ce752b7SAlex Zinenko 2881ce752b7SAlex Zinenko // Match select-based min/max reductions. 2891ce752b7SAlex Zinenko bool isMin; 290*a54f4eaeSMogball if (matchSelectReduction<arith::CmpFOp, SelectOp>( 291*a54f4eaeSMogball reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, 292*a54f4eaeSMogball {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || 2931ce752b7SAlex Zinenko matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>( 2941ce752b7SAlex Zinenko reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, 2951ce752b7SAlex Zinenko {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { 2961ce752b7SAlex Zinenko return createDecl(builder, symbolTable, reduce, 2971ce752b7SAlex Zinenko minMaxValueForFloat(type, !isMin)); 2981ce752b7SAlex Zinenko } 299*a54f4eaeSMogball if (matchSelectReduction<arith::CmpIOp, SelectOp>( 300*a54f4eaeSMogball reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, 301*a54f4eaeSMogball {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || 3021ce752b7SAlex Zinenko matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 3031ce752b7SAlex Zinenko reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, 3041ce752b7SAlex Zinenko {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { 3051ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl( 3061ce752b7SAlex Zinenko builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin)); 3071ce752b7SAlex Zinenko return addAtomicRMW(builder, 3081ce752b7SAlex Zinenko isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, 3091ce752b7SAlex Zinenko decl, reduce); 3101ce752b7SAlex Zinenko } 311*a54f4eaeSMogball if (matchSelectReduction<arith::CmpIOp, SelectOp>( 312*a54f4eaeSMogball reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, 313*a54f4eaeSMogball {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || 3141ce752b7SAlex Zinenko matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 3151ce752b7SAlex Zinenko reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, 3161ce752b7SAlex Zinenko {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { 3171ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl( 3181ce752b7SAlex Zinenko builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin)); 3191ce752b7SAlex Zinenko return addAtomicRMW( 3201ce752b7SAlex Zinenko builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, 3211ce752b7SAlex Zinenko decl, reduce); 3221ce752b7SAlex Zinenko } 3231ce752b7SAlex Zinenko 3241ce752b7SAlex Zinenko return nullptr; 3251ce752b7SAlex Zinenko } 3261ce752b7SAlex Zinenko 327119545f4SAlex Zinenko namespace { 328119545f4SAlex Zinenko 329119545f4SAlex Zinenko struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { 330119545f4SAlex Zinenko using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 331119545f4SAlex Zinenko 332119545f4SAlex Zinenko LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, 333119545f4SAlex Zinenko PatternRewriter &rewriter) const override { 334119545f4SAlex Zinenko // Replace SCF yield with OpenMP yield. 335119545f4SAlex Zinenko { 336119545f4SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 337119545f4SAlex Zinenko rewriter.setInsertionPointToEnd(parallelOp.getBody()); 338119545f4SAlex Zinenko assert(llvm::hasSingleElement(parallelOp.region()) && 339119545f4SAlex Zinenko "expected scf.parallel to have one block"); 340119545f4SAlex Zinenko rewriter.replaceOpWithNewOp<omp::YieldOp>( 341119545f4SAlex Zinenko parallelOp.getBody()->getTerminator(), ValueRange()); 342119545f4SAlex Zinenko } 343119545f4SAlex Zinenko 3441ce752b7SAlex Zinenko // Declare reductions. 3451ce752b7SAlex Zinenko // TODO: consider checking it here is already a compatible reduction 3461ce752b7SAlex Zinenko // declaration and use it instead of redeclaring. 3471ce752b7SAlex Zinenko SmallVector<Attribute> reductionDeclSymbols; 3481ce752b7SAlex Zinenko for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) { 3491ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce); 3501ce752b7SAlex Zinenko if (!decl) 3511ce752b7SAlex Zinenko return failure(); 3521ce752b7SAlex Zinenko reductionDeclSymbols.push_back( 3531ce752b7SAlex Zinenko SymbolRefAttr::get(rewriter.getContext(), decl.sym_name())); 3541ce752b7SAlex Zinenko } 3551ce752b7SAlex Zinenko 3561ce752b7SAlex Zinenko // Allocate reduction variables. Make sure the we don't overflow the stack 3571ce752b7SAlex Zinenko // with local `alloca`s by saving and restoring the stack pointer. 3581ce752b7SAlex Zinenko Location loc = parallelOp.getLoc(); 3591ce752b7SAlex Zinenko Value one = rewriter.create<LLVM::ConstantOp>( 3601ce752b7SAlex Zinenko loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); 3611ce752b7SAlex Zinenko SmallVector<Value> reductionVariables; 3621ce752b7SAlex Zinenko reductionVariables.reserve(parallelOp.getNumReductions()); 3631ce752b7SAlex Zinenko Value token = rewriter.create<LLVM::StackSaveOp>( 3641ce752b7SAlex Zinenko loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8))); 3651ce752b7SAlex Zinenko for (Value init : parallelOp.initVals()) { 3661ce752b7SAlex Zinenko assert((LLVM::isCompatibleType(init.getType()) || 3671ce752b7SAlex Zinenko init.getType().isa<LLVM::PointerElementTypeInterface>()) && 3681ce752b7SAlex Zinenko "cannot create a reduction variable if the type is not an LLVM " 3691ce752b7SAlex Zinenko "pointer element"); 3701ce752b7SAlex Zinenko Value storage = rewriter.create<LLVM::AllocaOp>( 3711ce752b7SAlex Zinenko loc, LLVM::LLVMPointerType::get(init.getType()), one, 0); 3721ce752b7SAlex Zinenko rewriter.create<LLVM::StoreOp>(loc, init, storage); 3731ce752b7SAlex Zinenko reductionVariables.push_back(storage); 3741ce752b7SAlex Zinenko } 3751ce752b7SAlex Zinenko 3761ce752b7SAlex Zinenko // Replace the reduction operations contained in this loop. Must be done 3771ce752b7SAlex Zinenko // here rather than in a separate pattern to have access to the list of 3781ce752b7SAlex Zinenko // reduction variables. 3791ce752b7SAlex Zinenko for (auto pair : 3801ce752b7SAlex Zinenko llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) { 3811ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 3821ce752b7SAlex Zinenko scf::ReduceOp reduceOp = std::get<0>(pair); 3831ce752b7SAlex Zinenko rewriter.setInsertionPoint(reduceOp); 3841ce752b7SAlex Zinenko rewriter.replaceOpWithNewOp<omp::ReductionOp>( 3851ce752b7SAlex Zinenko reduceOp, reduceOp.operand(), std::get<1>(pair)); 3861ce752b7SAlex Zinenko } 3871ce752b7SAlex Zinenko 3881ce752b7SAlex Zinenko // Create the parallel wrapper. 3891ce752b7SAlex Zinenko auto ompParallel = rewriter.create<omp::ParallelOp>(loc); 3901ce752b7SAlex Zinenko { 3911ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 3921ce752b7SAlex Zinenko rewriter.createBlock(&ompParallel.region()); 3931ce752b7SAlex Zinenko 3941ce752b7SAlex Zinenko // Replace SCF yield with OpenMP yield. 3951ce752b7SAlex Zinenko { 3961ce752b7SAlex Zinenko OpBuilder::InsertionGuard innerGuard(rewriter); 3971ce752b7SAlex Zinenko rewriter.setInsertionPointToEnd(parallelOp.getBody()); 3981ce752b7SAlex Zinenko assert(llvm::hasSingleElement(parallelOp.region()) && 3991ce752b7SAlex Zinenko "expected scf.parallel to have one block"); 4001ce752b7SAlex Zinenko rewriter.replaceOpWithNewOp<omp::YieldOp>( 4011ce752b7SAlex Zinenko parallelOp.getBody()->getTerminator(), ValueRange()); 4021ce752b7SAlex Zinenko } 4031ce752b7SAlex Zinenko 404119545f4SAlex Zinenko // Replace the loop. 405119545f4SAlex Zinenko auto loop = rewriter.create<omp::WsLoopOp>( 406119545f4SAlex Zinenko parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(), 407119545f4SAlex Zinenko parallelOp.step()); 4081ce752b7SAlex Zinenko rewriter.create<omp::TerminatorOp>(loc); 4091ce752b7SAlex Zinenko 410119545f4SAlex Zinenko rewriter.inlineRegionBefore(parallelOp.region(), loop.region(), 411119545f4SAlex Zinenko loop.region().begin()); 4121ce752b7SAlex Zinenko if (!reductionVariables.empty()) { 4131ce752b7SAlex Zinenko loop.reductionsAttr( 4141ce752b7SAlex Zinenko ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); 4151ce752b7SAlex Zinenko loop.reduction_varsMutable().append(reductionVariables); 4161ce752b7SAlex Zinenko } 4171ce752b7SAlex Zinenko } 418973cb2c3SWilliam S. Moses 4191ce752b7SAlex Zinenko // Load loop results. 4201ce752b7SAlex Zinenko SmallVector<Value> results; 4211ce752b7SAlex Zinenko results.reserve(reductionVariables.size()); 4221ce752b7SAlex Zinenko for (Value variable : reductionVariables) { 4231ce752b7SAlex Zinenko Value res = rewriter.create<LLVM::LoadOp>(loc, variable); 4241ce752b7SAlex Zinenko results.push_back(res); 4251ce752b7SAlex Zinenko } 4261ce752b7SAlex Zinenko rewriter.replaceOp(parallelOp, results); 4271ce752b7SAlex Zinenko 4281ce752b7SAlex Zinenko rewriter.create<LLVM::StackRestoreOp>(loc, token); 429119545f4SAlex Zinenko return success(); 430119545f4SAlex Zinenko } 431119545f4SAlex Zinenko }; 432119545f4SAlex Zinenko 433119545f4SAlex Zinenko /// Applies the conversion patterns in the given function. 4341ce752b7SAlex Zinenko static LogicalResult applyPatterns(ModuleOp module) { 4351ce752b7SAlex Zinenko ConversionTarget target(*module.getContext()); 4361ce752b7SAlex Zinenko target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); 4371ce752b7SAlex Zinenko target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect>(); 438119545f4SAlex Zinenko 4391ce752b7SAlex Zinenko RewritePatternSet patterns(module.getContext()); 4401ce752b7SAlex Zinenko patterns.add<ParallelOpLowering>(module.getContext()); 44179d7f618SChris Lattner FrozenRewritePatternSet frozen(std::move(patterns)); 4421ce752b7SAlex Zinenko return applyPartialConversion(module, target, frozen); 443119545f4SAlex Zinenko } 444119545f4SAlex Zinenko 445119545f4SAlex Zinenko /// A pass converting SCF operations to OpenMP operations. 446119545f4SAlex Zinenko struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> { 447119545f4SAlex Zinenko /// Pass entry point. 4481ce752b7SAlex Zinenko void runOnOperation() override { 4491ce752b7SAlex Zinenko if (failed(applyPatterns(getOperation()))) 450119545f4SAlex Zinenko signalPassFailure(); 451119545f4SAlex Zinenko } 452119545f4SAlex Zinenko }; 453119545f4SAlex Zinenko 454119545f4SAlex Zinenko } // end namespace 455119545f4SAlex Zinenko 4561ce752b7SAlex Zinenko std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() { 457119545f4SAlex Zinenko return std::make_unique<SCFToOpenMPPass>(); 458119545f4SAlex Zinenko } 459