1c25b20c0SAlex Zinenko //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// 2c25b20c0SAlex Zinenko // 3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c25b20c0SAlex Zinenko // 7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===// 8c25b20c0SAlex Zinenko // 9c25b20c0SAlex Zinenko // This file implements loop fusion on parallel loops. 10c25b20c0SAlex Zinenko // 11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===// 12c25b20c0SAlex Zinenko 13c25b20c0SAlex Zinenko #include "PassDetail.h" 14e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h" 15c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Passes.h" 16c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/SCF.h" 17c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Transforms.h" 18c25b20c0SAlex Zinenko #include "mlir/Dialect/StandardOps/IR/Ops.h" 19c25b20c0SAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h" 20c25b20c0SAlex Zinenko #include "mlir/IR/Builders.h" 21c25b20c0SAlex Zinenko #include "mlir/IR/OpDefinition.h" 22c25b20c0SAlex Zinenko 23c25b20c0SAlex Zinenko using namespace mlir; 24c25b20c0SAlex Zinenko using namespace mlir::scf; 25c25b20c0SAlex Zinenko 26c25b20c0SAlex Zinenko /// Verify there are no nested ParallelOps. 27c25b20c0SAlex Zinenko static bool hasNestedParallelOp(ParallelOp ploop) { 28c25b20c0SAlex Zinenko auto walkResult = 29c25b20c0SAlex Zinenko ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); 30c25b20c0SAlex Zinenko return walkResult.wasInterrupted(); 31c25b20c0SAlex Zinenko } 32c25b20c0SAlex Zinenko 33c25b20c0SAlex Zinenko /// Verify equal iteration spaces. 34c25b20c0SAlex Zinenko static bool equalIterationSpaces(ParallelOp firstPloop, 35c25b20c0SAlex Zinenko ParallelOp secondPloop) { 36c25b20c0SAlex Zinenko if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) 37c25b20c0SAlex Zinenko return false; 38c25b20c0SAlex Zinenko 39c25b20c0SAlex Zinenko auto matchOperands = [&](const OperandRange &lhs, 40c25b20c0SAlex Zinenko const OperandRange &rhs) -> bool { 41c25b20c0SAlex Zinenko // TODO: Extend this to support aliases and equal constants. 42c25b20c0SAlex Zinenko return std::equal(lhs.begin(), lhs.end(), rhs.begin()); 43c25b20c0SAlex Zinenko }; 44*c0342a2dSJacques Pienaar return matchOperands(firstPloop.getLowerBound(), 45*c0342a2dSJacques Pienaar secondPloop.getLowerBound()) && 46*c0342a2dSJacques Pienaar matchOperands(firstPloop.getUpperBound(), 47*c0342a2dSJacques Pienaar secondPloop.getUpperBound()) && 48*c0342a2dSJacques Pienaar matchOperands(firstPloop.getStep(), secondPloop.getStep()); 49c25b20c0SAlex Zinenko } 50c25b20c0SAlex Zinenko 51c25b20c0SAlex Zinenko /// Checks if the parallel loops have mixed access to the same buffers. Returns 52c25b20c0SAlex Zinenko /// `true` if the first parallel loop writes to the same indices that the second 53c25b20c0SAlex Zinenko /// loop reads. 54c25b20c0SAlex Zinenko static bool haveNoReadsAfterWriteExceptSameIndex( 55c25b20c0SAlex Zinenko ParallelOp firstPloop, ParallelOp secondPloop, 56c25b20c0SAlex Zinenko const BlockAndValueMapping &firstToSecondPloopIndices) { 57c25b20c0SAlex Zinenko DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; 58e2310704SJulian Gross firstPloop.getBody()->walk([&](memref::StoreOp store) { 59c25b20c0SAlex Zinenko bufferStores[store.getMemRef()].push_back(store.indices()); 60c25b20c0SAlex Zinenko }); 61e2310704SJulian Gross auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { 62c25b20c0SAlex Zinenko // Stop if the memref is defined in secondPloop body. Careful alias analysis 63c25b20c0SAlex Zinenko // is needed. 64c25b20c0SAlex Zinenko auto *memrefDef = load.getMemRef().getDefiningOp(); 65c4a04059SChristian Sigg if (memrefDef && memrefDef->getBlock() == load->getBlock()) 66c25b20c0SAlex Zinenko return WalkResult::interrupt(); 67c25b20c0SAlex Zinenko 68c25b20c0SAlex Zinenko auto write = bufferStores.find(load.getMemRef()); 69c25b20c0SAlex Zinenko if (write == bufferStores.end()) 70c25b20c0SAlex Zinenko return WalkResult::advance(); 71c25b20c0SAlex Zinenko 72c25b20c0SAlex Zinenko // Allow only single write access per buffer. 73c25b20c0SAlex Zinenko if (write->second.size() != 1) 74c25b20c0SAlex Zinenko return WalkResult::interrupt(); 75c25b20c0SAlex Zinenko 76c25b20c0SAlex Zinenko // Check that the load indices of secondPloop coincide with store indices of 77c25b20c0SAlex Zinenko // firstPloop for the same memrefs. 78c25b20c0SAlex Zinenko auto storeIndices = write->second.front(); 79c25b20c0SAlex Zinenko auto loadIndices = load.indices(); 80c25b20c0SAlex Zinenko if (storeIndices.size() != loadIndices.size()) 81c25b20c0SAlex Zinenko return WalkResult::interrupt(); 82c25b20c0SAlex Zinenko for (int i = 0, e = storeIndices.size(); i < e; ++i) { 83c25b20c0SAlex Zinenko if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != 84c25b20c0SAlex Zinenko loadIndices[i]) 85c25b20c0SAlex Zinenko return WalkResult::interrupt(); 86c25b20c0SAlex Zinenko } 87c25b20c0SAlex Zinenko return WalkResult::advance(); 88c25b20c0SAlex Zinenko }); 89c25b20c0SAlex Zinenko return !walkResult.wasInterrupted(); 90c25b20c0SAlex Zinenko } 91c25b20c0SAlex Zinenko 92c25b20c0SAlex Zinenko /// Analyzes dependencies in the most primitive way by checking simple read and 93c25b20c0SAlex Zinenko /// write patterns. 94c25b20c0SAlex Zinenko static LogicalResult 95c25b20c0SAlex Zinenko verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, 96c25b20c0SAlex Zinenko const BlockAndValueMapping &firstToSecondPloopIndices) { 97c25b20c0SAlex Zinenko if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop, 98c25b20c0SAlex Zinenko firstToSecondPloopIndices)) 99c25b20c0SAlex Zinenko return failure(); 100c25b20c0SAlex Zinenko 101c25b20c0SAlex Zinenko BlockAndValueMapping secondToFirstPloopIndices; 102c25b20c0SAlex Zinenko secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), 103c25b20c0SAlex Zinenko firstPloop.getBody()->getArguments()); 104c25b20c0SAlex Zinenko return success(haveNoReadsAfterWriteExceptSameIndex( 105c25b20c0SAlex Zinenko secondPloop, firstPloop, secondToFirstPloopIndices)); 106c25b20c0SAlex Zinenko } 107c25b20c0SAlex Zinenko 108c25b20c0SAlex Zinenko static bool 109c25b20c0SAlex Zinenko isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, 110c25b20c0SAlex Zinenko const BlockAndValueMapping &firstToSecondPloopIndices) { 111c25b20c0SAlex Zinenko return !hasNestedParallelOp(firstPloop) && 112c25b20c0SAlex Zinenko !hasNestedParallelOp(secondPloop) && 113c25b20c0SAlex Zinenko equalIterationSpaces(firstPloop, secondPloop) && 114c25b20c0SAlex Zinenko succeeded(verifyDependencies(firstPloop, secondPloop, 115c25b20c0SAlex Zinenko firstToSecondPloopIndices)); 116c25b20c0SAlex Zinenko } 117c25b20c0SAlex Zinenko 118c25b20c0SAlex Zinenko /// Prepends operations of firstPloop's body into secondPloop's body. 119c25b20c0SAlex Zinenko static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, 120c25b20c0SAlex Zinenko OpBuilder b) { 121c25b20c0SAlex Zinenko BlockAndValueMapping firstToSecondPloopIndices; 122c25b20c0SAlex Zinenko firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), 123c25b20c0SAlex Zinenko secondPloop.getBody()->getArguments()); 124c25b20c0SAlex Zinenko 125c25b20c0SAlex Zinenko if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices)) 126c25b20c0SAlex Zinenko return; 127c25b20c0SAlex Zinenko 128c25b20c0SAlex Zinenko b.setInsertionPointToStart(secondPloop.getBody()); 129c25b20c0SAlex Zinenko for (auto &op : firstPloop.getBody()->without_terminator()) 130c25b20c0SAlex Zinenko b.clone(op, firstToSecondPloopIndices); 131c25b20c0SAlex Zinenko firstPloop.erase(); 132c25b20c0SAlex Zinenko } 133c25b20c0SAlex Zinenko 134c25b20c0SAlex Zinenko void mlir::scf::naivelyFuseParallelOps(Region ®ion) { 135c25b20c0SAlex Zinenko OpBuilder b(region); 136c25b20c0SAlex Zinenko // Consider every single block and attempt to fuse adjacent loops. 137c25b20c0SAlex Zinenko for (auto &block : region) { 138c25b20c0SAlex Zinenko SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}}; 139c25b20c0SAlex Zinenko // Not using `walk()` to traverse only top-level parallel loops and also 140c25b20c0SAlex Zinenko // make sure that there are no side-effecting ops between the parallel 141c25b20c0SAlex Zinenko // loops. 142c25b20c0SAlex Zinenko bool noSideEffects = true; 143c25b20c0SAlex Zinenko for (auto &op : block) { 144c25b20c0SAlex Zinenko if (auto ploop = dyn_cast<ParallelOp>(op)) { 145c25b20c0SAlex Zinenko if (noSideEffects) { 146c25b20c0SAlex Zinenko ploopChains.back().push_back(ploop); 147c25b20c0SAlex Zinenko } else { 148c25b20c0SAlex Zinenko ploopChains.push_back({ploop}); 149c25b20c0SAlex Zinenko noSideEffects = true; 150c25b20c0SAlex Zinenko } 151c25b20c0SAlex Zinenko continue; 152c25b20c0SAlex Zinenko } 153c25b20c0SAlex Zinenko // TODO: Handle region side effects properly. 154c25b20c0SAlex Zinenko noSideEffects &= 155c25b20c0SAlex Zinenko MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0; 156c25b20c0SAlex Zinenko } 157c25b20c0SAlex Zinenko for (ArrayRef<ParallelOp> ploops : ploopChains) { 158c25b20c0SAlex Zinenko for (int i = 0, e = ploops.size(); i + 1 < e; ++i) 159c25b20c0SAlex Zinenko fuseIfLegal(ploops[i], ploops[i + 1], b); 160c25b20c0SAlex Zinenko } 161c25b20c0SAlex Zinenko } 162c25b20c0SAlex Zinenko } 163c25b20c0SAlex Zinenko 164c25b20c0SAlex Zinenko namespace { 165c25b20c0SAlex Zinenko struct ParallelLoopFusion 1664bcd08ebSStephan Herhut : public SCFParallelLoopFusionBase<ParallelLoopFusion> { 167c25b20c0SAlex Zinenko void runOnOperation() override { 168c25b20c0SAlex Zinenko getOperation()->walk([&](Operation *child) { 169c25b20c0SAlex Zinenko for (Region ®ion : child->getRegions()) 170c25b20c0SAlex Zinenko naivelyFuseParallelOps(region); 171c25b20c0SAlex Zinenko }); 172c25b20c0SAlex Zinenko } 173c25b20c0SAlex Zinenko }; 174c25b20c0SAlex Zinenko } // namespace 175c25b20c0SAlex Zinenko 176c25b20c0SAlex Zinenko std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { 177c25b20c0SAlex Zinenko return std::make_unique<ParallelLoopFusion>(); 178c25b20c0SAlex Zinenko } 179