1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the reduction nodes which are used to track of the 10 // metadata for a specific generated variant within a reduction pass and are the 11 // building blocks of the reduction tree structure. A reduction tree is used to 12 // keep track of the different generated variants throughout a reduction pass in 13 // the MLIR Reduce tool. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Reducer/ReductionNode.h" 18 #include "mlir/IR/BlockAndValueMapping.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 #include <algorithm> 22 #include <limits> 23 24 using namespace mlir; 25 26 ReductionNode::ReductionNode( 27 ReductionNode *parentNode, const std::vector<Range> &ranges, 28 llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator) 29 /// Root node will have the parent pointer point to themselves. 30 : parent(parentNode == nullptr ? this : parentNode), 31 size(std::numeric_limits<size_t>::max()), ranges(ranges), 32 startRanges(ranges), allocator(allocator) { 33 if (parent != this) 34 if (failed(initialize(parent->getModule(), parent->getRegion()))) 35 llvm_unreachable("unexpected initialization failure"); 36 } 37 38 LogicalResult ReductionNode::initialize(ModuleOp parentModule, 39 Region &targetRegion) { 40 // Use the mapper help us find the corresponding region after module clone. 41 BlockAndValueMapping mapper; 42 module = cast<ModuleOp>(parentModule->clone(mapper)); 43 // Use the first block of targetRegion to locate the cloned region. 44 Block *block = mapper.lookup(&*targetRegion.begin()); 45 region = block->getParent(); 46 return success(); 47 } 48 49 /// If we haven't explored any variants from this node, we will create N 50 /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the 51 /// max element in `ranges` and create 2 new variants for each call. 52 ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() { 53 int oldNumVariant = getVariants().size(); 54 55 auto createNewNode = [this](const std::vector<Range> &ranges) { 56 return new (allocator.Allocate()) ReductionNode(this, ranges, allocator); 57 }; 58 59 // If we haven't created new variant, then we can create varients by removing 60 // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can 61 // produce variants with range {{1, 3}} and {{4, 9}}. 62 if (variants.empty() && getRanges().size() > 1) { 63 for (const Range &range : getRanges()) { 64 std::vector<Range> subRanges = getRanges(); 65 llvm::erase_value(subRanges, range); 66 variants.push_back(createNewNode(subRanges)); 67 } 68 69 return getVariants().drop_front(oldNumVariant); 70 } 71 72 // At here, we have created the type of variants mentioned above. We would 73 // like to split the max range into 2 to create 2 new variants. Continue on 74 // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and 75 // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The 76 // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. 77 auto maxElement = std::max_element( 78 ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) { 79 return (lhs.second - lhs.first) > (rhs.second - rhs.first); 80 }); 81 82 // The length of range is less than 1, we can't split it to create new 83 // variant. 84 if (maxElement->second - maxElement->first <= 1) 85 return {}; 86 87 Range maxRange = *maxElement; 88 std::vector<Range> subRanges = getRanges(); 89 auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin()); 90 int half = (maxRange.first + maxRange.second) / 2; 91 *subRangesIter = std::make_pair(maxRange.first, half); 92 variants.push_back(createNewNode(subRanges)); 93 *subRangesIter = std::make_pair(half, maxRange.second); 94 variants.push_back(createNewNode(subRanges)); 95 96 auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second)); 97 it = ranges.insert(it, std::make_pair(maxRange.first, half)); 98 // Remove the range that has been split. 99 ranges.erase(it + 2); 100 101 return getVariants().drop_front(oldNumVariant); 102 } 103 104 void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) { 105 std::tie(interesting, size) = result; 106 // After applying reduction, the number of operation in the region may have 107 // changed. Non-interesting case won't be explored thus it's safe to keep it 108 // in a stale status. 109 if (interesting == Tester::Interestingness::True) { 110 // This module may has been updated. Reset the range. 111 ranges.clear(); 112 ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end())); 113 } else { 114 // Release the uninteresting module to save some memory. 115 module.release()->erase(); 116 } 117 } 118 119 ArrayRef<ReductionNode *> 120 ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) { 121 // Single Path: Traverses the smallest successful variant at each level until 122 // no new successful variants can be created at that level. 123 ArrayRef<ReductionNode *> variantsFromParent = 124 node->getParent()->getVariants(); 125 126 // The parent node created several variants and they may be waiting for 127 // examing interestingness. In Single Path approach, we will select the 128 // smallest variant to continue our exploration. Thus we should wait until the 129 // last variant to be examed then do the following traversal decision. 130 if (!llvm::all_of(variantsFromParent, [](ReductionNode *node) { 131 return node->isInteresting() != Tester::Interestingness::Untested; 132 })) { 133 return {}; 134 } 135 136 ReductionNode *smallest = nullptr; 137 for (ReductionNode *node : variantsFromParent) { 138 if (node->isInteresting() != Tester::Interestingness::True) 139 continue; 140 if (smallest == nullptr || node->getSize() < smallest->getSize()) 141 smallest = node; 142 } 143 144 if (smallest != nullptr && 145 smallest->getSize() < node->getParent()->getSize()) { 146 // We got a smallest one, keep traversing from this node. 147 node = smallest; 148 } else { 149 // None of these variants is interesting, let the parent node to generate 150 // more variants. 151 node = node->getParent(); 152 } 153 154 return node->generateNewVariants(); 155 } 156