1119545f4SAlex Zinenko //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===// 2119545f4SAlex Zinenko // 3119545f4SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4119545f4SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5119545f4SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6119545f4SAlex Zinenko // 7119545f4SAlex Zinenko //===----------------------------------------------------------------------===// 8119545f4SAlex Zinenko // 9119545f4SAlex Zinenko // This file implements a pass to convert scf.parallel operations into OpenMP 10119545f4SAlex Zinenko // parallel loops. 11119545f4SAlex Zinenko // 12119545f4SAlex Zinenko //===----------------------------------------------------------------------===// 13119545f4SAlex Zinenko 14119545f4SAlex Zinenko #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" 15119545f4SAlex Zinenko #include "../PassDetail.h" 16755dc07dSRiver Riddle #include "mlir/Analysis/SliceAnalysis.h" 17755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 18a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 191ce752b7SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 2078fb4f9dSWilliam S. Moses #include "mlir/Dialect/MemRef/IR/MemRef.h" 21119545f4SAlex Zinenko #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 22119545f4SAlex Zinenko #include "mlir/Dialect/SCF/SCF.h" 231ce752b7SAlex Zinenko #include "mlir/IR/ImplicitLocOpBuilder.h" 241ce752b7SAlex Zinenko #include "mlir/IR/SymbolTable.h" 25119545f4SAlex Zinenko #include "mlir/Transforms/DialectConversion.h" 26119545f4SAlex Zinenko 27119545f4SAlex Zinenko using namespace mlir; 28119545f4SAlex Zinenko 291ce752b7SAlex Zinenko /// Matches a block containing a "simple" reduction. The expected shape of the 301ce752b7SAlex Zinenko /// block is as follows. 311ce752b7SAlex Zinenko /// 321ce752b7SAlex Zinenko /// ^bb(%arg0, %arg1): 331ce752b7SAlex Zinenko /// %0 = OpTy(%arg0, %arg1) 341ce752b7SAlex Zinenko /// scf.reduce.return %0 351ce752b7SAlex Zinenko template <typename... OpTy> 361ce752b7SAlex Zinenko static bool matchSimpleReduction(Block &block) { 371ce752b7SAlex Zinenko if (block.empty() || llvm::hasSingleElement(block) || 381ce752b7SAlex Zinenko std::next(block.begin(), 2) != block.end()) 391ce752b7SAlex Zinenko return false; 402a876a71SDiego Caballero 412a876a71SDiego Caballero if (block.getNumArguments() != 2) 422a876a71SDiego Caballero return false; 432a876a71SDiego Caballero 442a876a71SDiego Caballero SmallVector<Operation *, 4> combinerOps; 452a876a71SDiego Caballero Value reducedVal = matchReduction({block.getArguments()[1]}, 462a876a71SDiego Caballero /*redPos=*/0, combinerOps); 472a876a71SDiego Caballero 482a876a71SDiego Caballero if (!reducedVal || !reducedVal.isa<BlockArgument>() || 492a876a71SDiego Caballero combinerOps.size() != 1) 502a876a71SDiego Caballero return false; 512a876a71SDiego Caballero 522a876a71SDiego Caballero return isa<OpTy...>(combinerOps[0]) && 531ce752b7SAlex Zinenko isa<scf::ReduceReturnOp>(block.back()) && 542a876a71SDiego Caballero block.front().getOperands() == block.getArguments(); 551ce752b7SAlex Zinenko } 561ce752b7SAlex Zinenko 571ce752b7SAlex Zinenko /// Matches a block containing a select-based min/max reduction. The types of 581ce752b7SAlex Zinenko /// select and compare operations are provided as template arguments. The 591ce752b7SAlex Zinenko /// comparison predicates suitable for min and max are provided as function 601ce752b7SAlex Zinenko /// arguments. If a reduction is matched, `ifMin` will be set if the reduction 611ce752b7SAlex Zinenko /// compute the minimum and unset if it computes the maximum, otherwise it 621ce752b7SAlex Zinenko /// remains unmodified. The expected shape of the block is as follows. 631ce752b7SAlex Zinenko /// 641ce752b7SAlex Zinenko /// ^bb(%arg0, %arg1): 651ce752b7SAlex Zinenko /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1) 661ce752b7SAlex Zinenko /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here. 671ce752b7SAlex Zinenko /// scf.reduce.return %1 681ce752b7SAlex Zinenko template < 691ce752b7SAlex Zinenko typename CompareOpTy, typename SelectOpTy, 7062fea88bSJacques Pienaar typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())> 711ce752b7SAlex Zinenko static bool 721ce752b7SAlex Zinenko matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates, 731ce752b7SAlex Zinenko ArrayRef<Predicate> greaterThanPredicates, bool &isMin) { 74dec8af70SRiver Riddle static_assert( 75dec8af70SRiver Riddle llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value, 76dec8af70SRiver Riddle "only arithmetic and llvm select ops are supported"); 771ce752b7SAlex Zinenko 781ce752b7SAlex Zinenko // Expect exactly three operations in the block. 791ce752b7SAlex Zinenko if (block.empty() || llvm::hasSingleElement(block) || 801ce752b7SAlex Zinenko std::next(block.begin(), 2) == block.end() || 811ce752b7SAlex Zinenko std::next(block.begin(), 3) != block.end()) 821ce752b7SAlex Zinenko return false; 831ce752b7SAlex Zinenko 841ce752b7SAlex Zinenko // Check op kinds. 851ce752b7SAlex Zinenko auto compare = dyn_cast<CompareOpTy>(block.front()); 861ce752b7SAlex Zinenko auto select = dyn_cast<SelectOpTy>(block.front().getNextNode()); 871ce752b7SAlex Zinenko auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back()); 881ce752b7SAlex Zinenko if (!compare || !select || !terminator) 891ce752b7SAlex Zinenko return false; 901ce752b7SAlex Zinenko 911ce752b7SAlex Zinenko // Block arguments must be compared. 921ce752b7SAlex Zinenko if (compare->getOperands() != block.getArguments()) 931ce752b7SAlex Zinenko return false; 941ce752b7SAlex Zinenko 951ce752b7SAlex Zinenko // Detect whether the comparison is less-than or greater-than, otherwise bail. 961ce752b7SAlex Zinenko bool isLess; 97cfb72fd3SJacques Pienaar if (llvm::find(lessThanPredicates, compare.getPredicate()) != 981ce752b7SAlex Zinenko lessThanPredicates.end()) { 991ce752b7SAlex Zinenko isLess = true; 100cfb72fd3SJacques Pienaar } else if (llvm::find(greaterThanPredicates, compare.getPredicate()) != 1011ce752b7SAlex Zinenko greaterThanPredicates.end()) { 1021ce752b7SAlex Zinenko isLess = false; 1031ce752b7SAlex Zinenko } else { 1041ce752b7SAlex Zinenko return false; 1051ce752b7SAlex Zinenko } 1061ce752b7SAlex Zinenko 107cfb72fd3SJacques Pienaar if (select.getCondition() != compare.getResult()) 1081ce752b7SAlex Zinenko return false; 1091ce752b7SAlex Zinenko 1101ce752b7SAlex Zinenko // Detect if the operands are swapped between cmpf and select. Match the 1111ce752b7SAlex Zinenko // comparison type with the requested type or with the opposite of the 1121ce752b7SAlex Zinenko // requested type if the operands are swapped. Use generic accessors because 1131ce752b7SAlex Zinenko // std and LLVM versions of select have different operand names but identical 1141ce752b7SAlex Zinenko // positions. 1151ce752b7SAlex Zinenko constexpr unsigned kTrueValue = 1; 1161ce752b7SAlex Zinenko constexpr unsigned kFalseValue = 2; 117cfb72fd3SJacques Pienaar bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() && 118cfb72fd3SJacques Pienaar select.getOperand(kFalseValue) == compare.getRhs(); 119cfb72fd3SJacques Pienaar bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() && 120cfb72fd3SJacques Pienaar select.getOperand(kFalseValue) == compare.getLhs(); 1211ce752b7SAlex Zinenko if (!sameOperands && !swappedOperands) 1221ce752b7SAlex Zinenko return false; 1231ce752b7SAlex Zinenko 124c0342a2dSJacques Pienaar if (select.getResult() != terminator.getResult()) 1251ce752b7SAlex Zinenko return false; 1261ce752b7SAlex Zinenko 1271ce752b7SAlex Zinenko // The reduction is a min if it uses less-than predicates with same operands 1281ce752b7SAlex Zinenko // or greather-than predicates with swapped operands. Similarly for max. 1291ce752b7SAlex Zinenko isMin = (isLess && sameOperands) || (!isLess && swappedOperands); 1301ce752b7SAlex Zinenko return isMin || (isLess & swappedOperands) || (!isLess && sameOperands); 1311ce752b7SAlex Zinenko } 1321ce752b7SAlex Zinenko 1331ce752b7SAlex Zinenko /// Returns the float semantics for the given float type. 1341ce752b7SAlex Zinenko static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { 1351ce752b7SAlex Zinenko if (type.isF16()) 1361ce752b7SAlex Zinenko return llvm::APFloat::IEEEhalf(); 1371ce752b7SAlex Zinenko if (type.isF32()) 1381ce752b7SAlex Zinenko return llvm::APFloat::IEEEsingle(); 1391ce752b7SAlex Zinenko if (type.isF64()) 1401ce752b7SAlex Zinenko return llvm::APFloat::IEEEdouble(); 1411ce752b7SAlex Zinenko if (type.isF128()) 1421ce752b7SAlex Zinenko return llvm::APFloat::IEEEquad(); 1431ce752b7SAlex Zinenko if (type.isBF16()) 1441ce752b7SAlex Zinenko return llvm::APFloat::BFloat(); 1451ce752b7SAlex Zinenko if (type.isF80()) 1461ce752b7SAlex Zinenko return llvm::APFloat::x87DoubleExtended(); 1471ce752b7SAlex Zinenko llvm_unreachable("unknown float type"); 1481ce752b7SAlex Zinenko } 1491ce752b7SAlex Zinenko 1501ce752b7SAlex Zinenko /// Returns an attribute with the minimum (if `min` is set) or the maximum value 1511ce752b7SAlex Zinenko /// (otherwise) for the given float type. 1521ce752b7SAlex Zinenko static Attribute minMaxValueForFloat(Type type, bool min) { 1531ce752b7SAlex Zinenko auto fltType = type.cast<FloatType>(); 1541ce752b7SAlex Zinenko return FloatAttr::get( 1551ce752b7SAlex Zinenko type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); 1561ce752b7SAlex Zinenko } 1571ce752b7SAlex Zinenko 1581ce752b7SAlex Zinenko /// Returns an attribute with the signed integer minimum (if `min` is set) or 1591ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its 1601ce752b7SAlex Zinenko /// signedness semantics (only the width is considered). 1611ce752b7SAlex Zinenko static Attribute minMaxValueForSignedInt(Type type, bool min) { 1621ce752b7SAlex Zinenko auto intType = type.cast<IntegerType>(); 1631ce752b7SAlex Zinenko unsigned bitwidth = intType.getWidth(); 1641ce752b7SAlex Zinenko return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) 1651ce752b7SAlex Zinenko : llvm::APInt::getSignedMaxValue(bitwidth)); 1661ce752b7SAlex Zinenko } 1671ce752b7SAlex Zinenko 1681ce752b7SAlex Zinenko /// Returns an attribute with the unsigned integer minimum (if `min` is set) or 1691ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its 1701ce752b7SAlex Zinenko /// signedness semantics (only the width is considered). 1711ce752b7SAlex Zinenko static Attribute minMaxValueForUnsignedInt(Type type, bool min) { 1721ce752b7SAlex Zinenko auto intType = type.cast<IntegerType>(); 1731ce752b7SAlex Zinenko unsigned bitwidth = intType.getWidth(); 1741ce752b7SAlex Zinenko return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth) 1751ce752b7SAlex Zinenko : llvm::APInt::getAllOnesValue(bitwidth)); 1761ce752b7SAlex Zinenko } 1771ce752b7SAlex Zinenko 1781ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration and inserts it into the provided 1791ce752b7SAlex Zinenko /// symbol table. The declaration has a constant initializer with the neutral 1801ce752b7SAlex Zinenko /// value `initValue`, and the reduction combiner carried over from `reduce`. 1811ce752b7SAlex Zinenko static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, 1821ce752b7SAlex Zinenko SymbolTable &symbolTable, 1831ce752b7SAlex Zinenko scf::ReduceOp reduce, 1841ce752b7SAlex Zinenko Attribute initValue) { 1851ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 1861ce752b7SAlex Zinenko auto decl = builder.create<omp::ReductionDeclareOp>( 187c0342a2dSJacques Pienaar reduce.getLoc(), "__scf_reduction", reduce.getOperand().getType()); 1881ce752b7SAlex Zinenko symbolTable.insert(decl); 1891ce752b7SAlex Zinenko 190c0342a2dSJacques Pienaar Type type = reduce.getOperand().getType(); 1911ce752b7SAlex Zinenko builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), 192e084679fSRiver Riddle {type}, {reduce.getOperand().getLoc()}); 1931ce752b7SAlex Zinenko builder.setInsertionPointToEnd(&decl.initializerRegion().back()); 1941ce752b7SAlex Zinenko Value init = 1951ce752b7SAlex Zinenko builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue); 1961ce752b7SAlex Zinenko builder.create<omp::YieldOp>(reduce.getLoc(), init); 1971ce752b7SAlex Zinenko 1981ce752b7SAlex Zinenko Operation *terminator = &reduce.getRegion().front().back(); 1991ce752b7SAlex Zinenko assert(isa<scf::ReduceReturnOp>(terminator) && 2001ce752b7SAlex Zinenko "expected reduce op to be terminated by redure return"); 2011ce752b7SAlex Zinenko builder.setInsertionPoint(terminator); 2021ce752b7SAlex Zinenko builder.replaceOpWithNewOp<omp::YieldOp>(terminator, 2031ce752b7SAlex Zinenko terminator->getOperands()); 2041ce752b7SAlex Zinenko builder.inlineRegionBefore(reduce.getRegion(), decl.reductionRegion(), 2051ce752b7SAlex Zinenko decl.reductionRegion().end()); 2061ce752b7SAlex Zinenko return decl; 2071ce752b7SAlex Zinenko } 2081ce752b7SAlex Zinenko 2091ce752b7SAlex Zinenko /// Adds an atomic reduction combiner to the given OpenMP reduction declaration 2101ce752b7SAlex Zinenko /// using llvm.atomicrmw of the given kind. 2111ce752b7SAlex Zinenko static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, 2121ce752b7SAlex Zinenko LLVM::AtomicBinOp atomicKind, 2131ce752b7SAlex Zinenko omp::ReductionDeclareOp decl, 2141ce752b7SAlex Zinenko scf::ReduceOp reduce) { 2151ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 216c0342a2dSJacques Pienaar Type type = reduce.getOperand().getType(); 2171ce752b7SAlex Zinenko Type ptrType = LLVM::LLVMPointerType::get(type); 218e084679fSRiver Riddle Location reduceOperandLoc = reduce.getOperand().getLoc(); 2191ce752b7SAlex Zinenko builder.createBlock(&decl.atomicReductionRegion(), 220e084679fSRiver Riddle decl.atomicReductionRegion().end(), {ptrType, ptrType}, 221e084679fSRiver Riddle {reduceOperandLoc, reduceOperandLoc}); 2221ce752b7SAlex Zinenko Block *atomicBlock = &decl.atomicReductionRegion().back(); 2231ce752b7SAlex Zinenko builder.setInsertionPointToEnd(atomicBlock); 2241ce752b7SAlex Zinenko Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), 2251ce752b7SAlex Zinenko atomicBlock->getArgument(1)); 2261ce752b7SAlex Zinenko builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), type, atomicKind, 2271ce752b7SAlex Zinenko atomicBlock->getArgument(0), loaded, 2281ce752b7SAlex Zinenko LLVM::AtomicOrdering::monotonic); 2291ce752b7SAlex Zinenko builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>()); 2301ce752b7SAlex Zinenko return decl; 2311ce752b7SAlex Zinenko } 2321ce752b7SAlex Zinenko 2331ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration that corresponds to the given SCF 2341ce752b7SAlex Zinenko /// reduction and returns it. Recognizes common reductions in order to identify 2351ce752b7SAlex Zinenko /// the neutral value, necessary for the OpenMP declaration. If the reduction 2361ce752b7SAlex Zinenko /// cannot be recognized, returns null. 2371ce752b7SAlex Zinenko static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, 2381ce752b7SAlex Zinenko scf::ReduceOp reduce) { 2391ce752b7SAlex Zinenko Operation *container = SymbolTable::getNearestSymbolTable(reduce); 2401ce752b7SAlex Zinenko SymbolTable symbolTable(container); 2411ce752b7SAlex Zinenko 2421ce752b7SAlex Zinenko // Insert reduction declarations in the symbol-table ancestor before the 2431ce752b7SAlex Zinenko // ancestor of the current insertion point. 2441ce752b7SAlex Zinenko Operation *insertionPoint = reduce; 2451ce752b7SAlex Zinenko while (insertionPoint->getParentOp() != container) 2461ce752b7SAlex Zinenko insertionPoint = insertionPoint->getParentOp(); 2471ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 2481ce752b7SAlex Zinenko builder.setInsertionPoint(insertionPoint); 2491ce752b7SAlex Zinenko 2501ce752b7SAlex Zinenko assert(llvm::hasSingleElement(reduce.getRegion()) && 2511ce752b7SAlex Zinenko "expected reduction region to have a single element"); 2521ce752b7SAlex Zinenko 2531ce752b7SAlex Zinenko // Match simple binary reductions that can be expressed with atomicrmw. 254c0342a2dSJacques Pienaar Type type = reduce.getOperand().getType(); 2551ce752b7SAlex Zinenko Block &reduction = reduce.getRegion().front(); 256a54f4eaeSMogball if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) { 2571ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2581ce752b7SAlex Zinenko builder.getFloatAttr(type, 0.0)); 2591ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); 2601ce752b7SAlex Zinenko } 261a54f4eaeSMogball if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) { 2621ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2631ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 2641ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); 2651ce752b7SAlex Zinenko } 266a54f4eaeSMogball if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) { 2671ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2681ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 2691ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); 2701ce752b7SAlex Zinenko } 271a54f4eaeSMogball if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) { 2721ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 2731ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 2741ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); 2751ce752b7SAlex Zinenko } 276a54f4eaeSMogball if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) { 2771ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl( 2781ce752b7SAlex Zinenko builder, symbolTable, reduce, 2791ce752b7SAlex Zinenko builder.getIntegerAttr( 2801ce752b7SAlex Zinenko type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth()))); 2811ce752b7SAlex Zinenko return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce); 2821ce752b7SAlex Zinenko } 2831ce752b7SAlex Zinenko 2841ce752b7SAlex Zinenko // Match simple binary reductions that cannot be expressed with atomicrmw. 2851ce752b7SAlex Zinenko // TODO: add atomic region using cmpxchg (which needs atomic load to be 2861ce752b7SAlex Zinenko // available as an op). 287a54f4eaeSMogball if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) { 2881ce752b7SAlex Zinenko return createDecl(builder, symbolTable, reduce, 2891ce752b7SAlex Zinenko builder.getFloatAttr(type, 1.0)); 2901ce752b7SAlex Zinenko } 2911ce752b7SAlex Zinenko 2921ce752b7SAlex Zinenko // Match select-based min/max reductions. 2931ce752b7SAlex Zinenko bool isMin; 294dec8af70SRiver Riddle if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>( 295a54f4eaeSMogball reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, 296a54f4eaeSMogball {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || 2971ce752b7SAlex Zinenko matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>( 2981ce752b7SAlex Zinenko reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, 2991ce752b7SAlex Zinenko {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { 3001ce752b7SAlex Zinenko return createDecl(builder, symbolTable, reduce, 3011ce752b7SAlex Zinenko minMaxValueForFloat(type, !isMin)); 3021ce752b7SAlex Zinenko } 303dec8af70SRiver Riddle if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( 304a54f4eaeSMogball reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, 305a54f4eaeSMogball {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || 3061ce752b7SAlex Zinenko matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 3071ce752b7SAlex Zinenko reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, 3081ce752b7SAlex Zinenko {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { 3091ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl( 3101ce752b7SAlex Zinenko builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin)); 3111ce752b7SAlex Zinenko return addAtomicRMW(builder, 3121ce752b7SAlex Zinenko isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, 3131ce752b7SAlex Zinenko decl, reduce); 3141ce752b7SAlex Zinenko } 315dec8af70SRiver Riddle if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( 316a54f4eaeSMogball reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, 317a54f4eaeSMogball {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || 3181ce752b7SAlex Zinenko matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 3191ce752b7SAlex Zinenko reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, 3201ce752b7SAlex Zinenko {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { 3211ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = createDecl( 3221ce752b7SAlex Zinenko builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin)); 3231ce752b7SAlex Zinenko return addAtomicRMW( 3241ce752b7SAlex Zinenko builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, 3251ce752b7SAlex Zinenko decl, reduce); 3261ce752b7SAlex Zinenko } 3271ce752b7SAlex Zinenko 3281ce752b7SAlex Zinenko return nullptr; 3291ce752b7SAlex Zinenko } 3301ce752b7SAlex Zinenko 331119545f4SAlex Zinenko namespace { 332119545f4SAlex Zinenko 333119545f4SAlex Zinenko struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { 334119545f4SAlex Zinenko using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 335119545f4SAlex Zinenko 336119545f4SAlex Zinenko LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, 337119545f4SAlex Zinenko PatternRewriter &rewriter) const override { 3381ce752b7SAlex Zinenko // Declare reductions. 3391ce752b7SAlex Zinenko // TODO: consider checking it here is already a compatible reduction 3401ce752b7SAlex Zinenko // declaration and use it instead of redeclaring. 3411ce752b7SAlex Zinenko SmallVector<Attribute> reductionDeclSymbols; 3421ce752b7SAlex Zinenko for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) { 3431ce752b7SAlex Zinenko omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce); 3441ce752b7SAlex Zinenko if (!decl) 3451ce752b7SAlex Zinenko return failure(); 3461ce752b7SAlex Zinenko reductionDeclSymbols.push_back( 3471ce752b7SAlex Zinenko SymbolRefAttr::get(rewriter.getContext(), decl.sym_name())); 3481ce752b7SAlex Zinenko } 3491ce752b7SAlex Zinenko 3501ce752b7SAlex Zinenko // Allocate reduction variables. Make sure the we don't overflow the stack 3511ce752b7SAlex Zinenko // with local `alloca`s by saving and restoring the stack pointer. 3521ce752b7SAlex Zinenko Location loc = parallelOp.getLoc(); 3531ce752b7SAlex Zinenko Value one = rewriter.create<LLVM::ConstantOp>( 3541ce752b7SAlex Zinenko loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); 3551ce752b7SAlex Zinenko SmallVector<Value> reductionVariables; 3561ce752b7SAlex Zinenko reductionVariables.reserve(parallelOp.getNumReductions()); 357c0342a2dSJacques Pienaar for (Value init : parallelOp.getInitVals()) { 3581ce752b7SAlex Zinenko assert((LLVM::isCompatibleType(init.getType()) || 3591ce752b7SAlex Zinenko init.getType().isa<LLVM::PointerElementTypeInterface>()) && 3601ce752b7SAlex Zinenko "cannot create a reduction variable if the type is not an LLVM " 3611ce752b7SAlex Zinenko "pointer element"); 3621ce752b7SAlex Zinenko Value storage = rewriter.create<LLVM::AllocaOp>( 3631ce752b7SAlex Zinenko loc, LLVM::LLVMPointerType::get(init.getType()), one, 0); 3641ce752b7SAlex Zinenko rewriter.create<LLVM::StoreOp>(loc, init, storage); 3651ce752b7SAlex Zinenko reductionVariables.push_back(storage); 3661ce752b7SAlex Zinenko } 3671ce752b7SAlex Zinenko 3681ce752b7SAlex Zinenko // Replace the reduction operations contained in this loop. Must be done 3691ce752b7SAlex Zinenko // here rather than in a separate pattern to have access to the list of 3701ce752b7SAlex Zinenko // reduction variables. 3711ce752b7SAlex Zinenko for (auto pair : 3721ce752b7SAlex Zinenko llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) { 3731ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 3741ce752b7SAlex Zinenko scf::ReduceOp reduceOp = std::get<0>(pair); 3751ce752b7SAlex Zinenko rewriter.setInsertionPoint(reduceOp); 3761ce752b7SAlex Zinenko rewriter.replaceOpWithNewOp<omp::ReductionOp>( 377c0342a2dSJacques Pienaar reduceOp, reduceOp.getOperand(), std::get<1>(pair)); 3781ce752b7SAlex Zinenko } 3791ce752b7SAlex Zinenko 3801ce752b7SAlex Zinenko // Create the parallel wrapper. 3811ce752b7SAlex Zinenko auto ompParallel = rewriter.create<omp::ParallelOp>(loc); 3821ce752b7SAlex Zinenko { 38378fb4f9dSWilliam S. Moses 3841ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 3851ce752b7SAlex Zinenko rewriter.createBlock(&ompParallel.region()); 3861ce752b7SAlex Zinenko 387119545f4SAlex Zinenko // Replace the loop. 388*bf6477ebSWilliam S. Moses { 389*bf6477ebSWilliam S. Moses OpBuilder::InsertionGuard allocaGuard(rewriter); 390119545f4SAlex Zinenko auto loop = rewriter.create<omp::WsLoopOp>( 391c0342a2dSJacques Pienaar parallelOp.getLoc(), parallelOp.getLowerBound(), 392c0342a2dSJacques Pienaar parallelOp.getUpperBound(), parallelOp.getStep()); 393*bf6477ebSWilliam S. Moses rewriter.create<omp::TerminatorOp>(loc); 3941ce752b7SAlex Zinenko 395c0342a2dSJacques Pienaar rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), 396119545f4SAlex Zinenko loop.region().begin()); 397*bf6477ebSWilliam S. Moses 398*bf6477ebSWilliam S. Moses Block *ops = rewriter.splitBlock(&*loop.region().begin(), 399*bf6477ebSWilliam S. Moses loop.region().begin()->begin()); 400*bf6477ebSWilliam S. Moses 401*bf6477ebSWilliam S. Moses rewriter.setInsertionPointToStart(&*loop.region().begin()); 402*bf6477ebSWilliam S. Moses 403*bf6477ebSWilliam S. Moses auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(), 404*bf6477ebSWilliam S. Moses TypeRange()); 405*bf6477ebSWilliam S. Moses rewriter.create<omp::YieldOp>(loc, ValueRange()); 406*bf6477ebSWilliam S. Moses Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); 407*bf6477ebSWilliam S. Moses rewriter.mergeBlocks(ops, scopeBlock); 408*bf6477ebSWilliam S. Moses auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator()); 409*bf6477ebSWilliam S. Moses rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); 410*bf6477ebSWilliam S. Moses rewriter.replaceOpWithNewOp<memref::AllocaScopeReturnOp>( 411*bf6477ebSWilliam S. Moses oldYield, oldYield->getOperands()); 4121ce752b7SAlex Zinenko if (!reductionVariables.empty()) { 4131ce752b7SAlex Zinenko loop.reductionsAttr( 4141ce752b7SAlex Zinenko ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); 4151ce752b7SAlex Zinenko loop.reduction_varsMutable().append(reductionVariables); 4161ce752b7SAlex Zinenko } 4171ce752b7SAlex Zinenko } 41878fb4f9dSWilliam S. Moses } 419973cb2c3SWilliam S. Moses 4201ce752b7SAlex Zinenko // Load loop results. 4211ce752b7SAlex Zinenko SmallVector<Value> results; 4221ce752b7SAlex Zinenko results.reserve(reductionVariables.size()); 4231ce752b7SAlex Zinenko for (Value variable : reductionVariables) { 4241ce752b7SAlex Zinenko Value res = rewriter.create<LLVM::LoadOp>(loc, variable); 4251ce752b7SAlex Zinenko results.push_back(res); 4261ce752b7SAlex Zinenko } 4271ce752b7SAlex Zinenko rewriter.replaceOp(parallelOp, results); 4281ce752b7SAlex Zinenko 429119545f4SAlex Zinenko return success(); 430119545f4SAlex Zinenko } 431119545f4SAlex Zinenko }; 432119545f4SAlex Zinenko 433119545f4SAlex Zinenko /// Applies the conversion patterns in the given function. 4341ce752b7SAlex Zinenko static LogicalResult applyPatterns(ModuleOp module) { 4351ce752b7SAlex Zinenko ConversionTarget target(*module.getContext()); 4361ce752b7SAlex Zinenko target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); 43778fb4f9dSWilliam S. Moses target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect, 43878fb4f9dSWilliam S. Moses memref::MemRefDialect>(); 439119545f4SAlex Zinenko 4401ce752b7SAlex Zinenko RewritePatternSet patterns(module.getContext()); 4411ce752b7SAlex Zinenko patterns.add<ParallelOpLowering>(module.getContext()); 44279d7f618SChris Lattner FrozenRewritePatternSet frozen(std::move(patterns)); 4431ce752b7SAlex Zinenko return applyPartialConversion(module, target, frozen); 444119545f4SAlex Zinenko } 445119545f4SAlex Zinenko 446119545f4SAlex Zinenko /// A pass converting SCF operations to OpenMP operations. 447119545f4SAlex Zinenko struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> { 448119545f4SAlex Zinenko /// Pass entry point. 4491ce752b7SAlex Zinenko void runOnOperation() override { 4501ce752b7SAlex Zinenko if (failed(applyPatterns(getOperation()))) 451119545f4SAlex Zinenko signalPassFailure(); 452119545f4SAlex Zinenko } 453119545f4SAlex Zinenko }; 454119545f4SAlex Zinenko 455be0a7e9fSMehdi Amini } // namespace 456119545f4SAlex Zinenko 4571ce752b7SAlex Zinenko std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() { 458119545f4SAlex Zinenko return std::make_unique<SCFToOpenMPPass>(); 459119545f4SAlex Zinenko } 460