1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test various loop fusion utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Analysis/Utils.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 16 #include "mlir/Dialect/Affine/LoopUtils.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/Pass/Pass.h" 19 20 #define DEBUG_TYPE "test-loop-fusion" 21 22 using namespace mlir; 23 24 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 25 26 static llvm::cl::opt<bool> clTestDependenceCheck( 27 "test-loop-fusion-dependence-check", 28 llvm::cl::desc("Enable testing of loop fusion dependence check"), 29 llvm::cl::cat(clOptionsCategory)); 30 31 static llvm::cl::opt<bool> clTestSliceComputation( 32 "test-loop-fusion-slice-computation", 33 llvm::cl::desc("Enable testing of loop fusion slice computation"), 34 llvm::cl::cat(clOptionsCategory)); 35 36 static llvm::cl::opt<bool> clTestLoopFusionTransformation( 37 "test-loop-fusion-transformation", 38 llvm::cl::desc("Enable testing of loop fusion transformation"), 39 llvm::cl::cat(clOptionsCategory)); 40 41 namespace { 42 43 struct TestLoopFusion 44 : public PassWrapper<TestLoopFusion, OperationPass<FuncOp>> { 45 StringRef getArgument() const final { return "test-loop-fusion"; } 46 StringRef getDescription() const final { 47 return "Tests loop fusion utility functions."; 48 } 49 void runOnOperation() override; 50 }; 51 52 } // namespace 53 54 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths 55 // in range ['loopDepth' + 1, 'maxLoopDepth']. 56 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. 57 // Returns false as IR is not transformed. 58 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp, 59 unsigned i, unsigned j, unsigned loopDepth, 60 unsigned maxLoopDepth) { 61 mlir::ComputationSliceState sliceUnion; 62 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 63 FusionResult result = 64 mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion); 65 if (result.value == FusionResult::FailBlockDependence) { 66 srcForOp->emitRemark("block-level dependence preventing" 67 " fusion of loop nest ") 68 << i << " into loop nest " << j << " at depth " << loopDepth; 69 } 70 } 71 return false; 72 } 73 74 // Returns the index of 'op' in its block. 75 static unsigned getBlockIndex(Operation &op) { 76 unsigned index = 0; 77 for (auto &opX : *op.getBlock()) { 78 if (&op == &opX) 79 break; 80 ++index; 81 } 82 return index; 83 } 84 85 // Returns a string representation of 'sliceUnion'. 86 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) { 87 std::string result; 88 llvm::raw_string_ostream os(result); 89 // Slice insertion point format [loop-depth, operation-block-index] 90 unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint); 91 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint); 92 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb) 93 << ")"; 94 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size()); 95 os << " loop bounds: "; 96 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) { 97 os << '['; 98 sliceUnion.lbs[k].print(os); 99 os << ", "; 100 sliceUnion.ubs[k].print(os); 101 os << "] "; 102 } 103 return os.str(); 104 } 105 106 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths 107 /// in range ['loopDepth' + 1, 'maxLoopDepth']. 108 /// Emits a string representation of the slice union as a remark on 'loops[j]' 109 /// and marks this as incorrect slice if the slice is invalid. Returns false as 110 /// IR is not transformed. 111 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB, 112 unsigned i, unsigned j, unsigned loopDepth, 113 unsigned maxLoopDepth) { 114 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 115 mlir::ComputationSliceState sliceUnion; 116 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 117 if (result.value == FusionResult::Success) { 118 forOpB->emitRemark("slice (") 119 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 120 << " : " << getSliceStr(sliceUnion) << ")"; 121 } else if (result.value == FusionResult::FailIncorrectSlice) { 122 forOpB->emitRemark("Incorrect slice (") 123 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 124 << " : " << getSliceStr(sliceUnion) << ")"; 125 } 126 } 127 return false; 128 } 129 130 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range 131 // ['loopDepth' + 1, 'maxLoopDepth']. 132 // Returns true if loops were successfully fused, false otherwise. 133 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB, 134 unsigned i, unsigned j, 135 unsigned loopDepth, 136 unsigned maxLoopDepth) { 137 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 138 mlir::ComputationSliceState sliceUnion; 139 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 140 if (result.value == FusionResult::Success) { 141 mlir::fuseLoops(forOpA, forOpB, sliceUnion); 142 // Note: 'forOpA' is removed to simplify test output. A proper loop 143 // fusion pass should check the data dependence graph and run memref 144 // region analysis to ensure removing 'forOpA' is safe. 145 forOpA.erase(); 146 return true; 147 } 148 } 149 return false; 150 } 151 152 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned, 153 unsigned, unsigned)>; 154 155 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 156 // If 'return_on_change' is true, returns on first invocation of 'fn' which 157 // returns true. 158 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops, 159 LoopFunc fn, bool returnOnChange = false) { 160 bool changed = false; 161 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end; 162 ++loopDepth) { 163 auto &loops = depthToLoops[loopDepth]; 164 unsigned numLoops = loops.size(); 165 for (unsigned j = 0; j < numLoops; ++j) { 166 for (unsigned k = 0; k < numLoops; ++k) { 167 if (j != k) 168 changed |= 169 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size()); 170 if (changed && returnOnChange) 171 return true; 172 } 173 } 174 } 175 return changed; 176 } 177 178 void TestLoopFusion::runOnOperation() { 179 std::vector<SmallVector<AffineForOp, 2>> depthToLoops; 180 if (clTestLoopFusionTransformation) { 181 // Run loop fusion until a fixed point is reached. 182 do { 183 depthToLoops.clear(); 184 // Gather all AffineForOps by loop depth. 185 gatherLoops(getOperation(), depthToLoops); 186 187 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'. 188 } while (iterateLoops(depthToLoops, testLoopFusionTransformation, 189 /*returnOnChange=*/true)); 190 return; 191 } 192 193 // Gather all AffineForOps by loop depth. 194 gatherLoops(getOperation(), depthToLoops); 195 196 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 197 if (clTestDependenceCheck) 198 iterateLoops(depthToLoops, testDependenceCheck); 199 if (clTestSliceComputation) 200 iterateLoops(depthToLoops, testSliceComputation); 201 } 202 203 namespace mlir { 204 namespace test { 205 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); } 206 } // namespace test 207 } // namespace mlir 208