1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test various loop fusion utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Analysis/Utils.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 16 #include "mlir/Dialect/Affine/LoopUtils.h" 17 #include "mlir/Pass/Pass.h" 18 19 #define DEBUG_TYPE "test-loop-fusion" 20 21 using namespace mlir; 22 23 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 24 25 static llvm::cl::opt<bool> clTestDependenceCheck( 26 "test-loop-fusion-dependence-check", 27 llvm::cl::desc("Enable testing of loop fusion dependence check"), 28 llvm::cl::cat(clOptionsCategory)); 29 30 static llvm::cl::opt<bool> clTestSliceComputation( 31 "test-loop-fusion-slice-computation", 32 llvm::cl::desc("Enable testing of loop fusion slice computation"), 33 llvm::cl::cat(clOptionsCategory)); 34 35 static llvm::cl::opt<bool> clTestLoopFusionTransformation( 36 "test-loop-fusion-transformation", 37 llvm::cl::desc("Enable testing of loop fusion transformation"), 38 llvm::cl::cat(clOptionsCategory)); 39 40 namespace { 41 42 struct TestLoopFusion 43 : public PassWrapper<TestLoopFusion, OperationPass<FuncOp>> { 44 StringRef getArgument() const final { return "test-loop-fusion"; } 45 StringRef getDescription() const final { 46 return "Tests loop fusion utility functions."; 47 } 48 void runOnOperation() override; 49 }; 50 51 } // namespace 52 53 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths 54 // in range ['loopDepth' + 1, 'maxLoopDepth']. 55 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. 56 // Returns false as IR is not transformed. 57 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp, 58 unsigned i, unsigned j, unsigned loopDepth, 59 unsigned maxLoopDepth) { 60 mlir::ComputationSliceState sliceUnion; 61 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 62 FusionResult result = 63 mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion); 64 if (result.value == FusionResult::FailBlockDependence) { 65 srcForOp->emitRemark("block-level dependence preventing" 66 " fusion of loop nest ") 67 << i << " into loop nest " << j << " at depth " << loopDepth; 68 } 69 } 70 return false; 71 } 72 73 // Returns the index of 'op' in its block. 74 static unsigned getBlockIndex(Operation &op) { 75 unsigned index = 0; 76 for (auto &opX : *op.getBlock()) { 77 if (&op == &opX) 78 break; 79 ++index; 80 } 81 return index; 82 } 83 84 // Returns a string representation of 'sliceUnion'. 85 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) { 86 std::string result; 87 llvm::raw_string_ostream os(result); 88 // Slice insertion point format [loop-depth, operation-block-index] 89 unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint); 90 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint); 91 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb) 92 << ")"; 93 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size()); 94 os << " loop bounds: "; 95 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) { 96 os << '['; 97 sliceUnion.lbs[k].print(os); 98 os << ", "; 99 sliceUnion.ubs[k].print(os); 100 os << "] "; 101 } 102 return os.str(); 103 } 104 105 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths 106 /// in range ['loopDepth' + 1, 'maxLoopDepth']. 107 /// Emits a string representation of the slice union as a remark on 'loops[j]' 108 /// and marks this as incorrect slice if the slice is invalid. Returns false as 109 /// IR is not transformed. 110 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB, 111 unsigned i, unsigned j, unsigned loopDepth, 112 unsigned maxLoopDepth) { 113 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 114 mlir::ComputationSliceState sliceUnion; 115 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 116 if (result.value == FusionResult::Success) { 117 forOpB->emitRemark("slice (") 118 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 119 << " : " << getSliceStr(sliceUnion) << ")"; 120 } else if (result.value == FusionResult::FailIncorrectSlice) { 121 forOpB->emitRemark("Incorrect slice (") 122 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 123 << " : " << getSliceStr(sliceUnion) << ")"; 124 } 125 } 126 return false; 127 } 128 129 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range 130 // ['loopDepth' + 1, 'maxLoopDepth']. 131 // Returns true if loops were successfully fused, false otherwise. 132 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB, 133 unsigned i, unsigned j, 134 unsigned loopDepth, 135 unsigned maxLoopDepth) { 136 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 137 mlir::ComputationSliceState sliceUnion; 138 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 139 if (result.value == FusionResult::Success) { 140 mlir::fuseLoops(forOpA, forOpB, sliceUnion); 141 // Note: 'forOpA' is removed to simplify test output. A proper loop 142 // fusion pass should check the data dependence graph and run memref 143 // region analysis to ensure removing 'forOpA' is safe. 144 forOpA.erase(); 145 return true; 146 } 147 } 148 return false; 149 } 150 151 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned, 152 unsigned, unsigned)>; 153 154 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 155 // If 'return_on_change' is true, returns on first invocation of 'fn' which 156 // returns true. 157 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops, 158 LoopFunc fn, bool returnOnChange = false) { 159 bool changed = false; 160 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end; 161 ++loopDepth) { 162 auto &loops = depthToLoops[loopDepth]; 163 unsigned numLoops = loops.size(); 164 for (unsigned j = 0; j < numLoops; ++j) { 165 for (unsigned k = 0; k < numLoops; ++k) { 166 if (j != k) 167 changed |= 168 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size()); 169 if (changed && returnOnChange) 170 return true; 171 } 172 } 173 } 174 return changed; 175 } 176 177 void TestLoopFusion::runOnOperation() { 178 std::vector<SmallVector<AffineForOp, 2>> depthToLoops; 179 if (clTestLoopFusionTransformation) { 180 // Run loop fusion until a fixed point is reached. 181 do { 182 depthToLoops.clear(); 183 // Gather all AffineForOps by loop depth. 184 gatherLoops(getOperation(), depthToLoops); 185 186 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'. 187 } while (iterateLoops(depthToLoops, testLoopFusionTransformation, 188 /*returnOnChange=*/true)); 189 return; 190 } 191 192 // Gather all AffineForOps by loop depth. 193 gatherLoops(getOperation(), depthToLoops); 194 195 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 196 if (clTestDependenceCheck) 197 iterateLoops(depthToLoops, testDependenceCheck); 198 if (clTestSliceComputation) 199 iterateLoops(depthToLoops, testSliceComputation); 200 } 201 202 namespace mlir { 203 namespace test { 204 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); } 205 } // namespace test 206 } // namespace mlir 207