//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test various loop fusion utility functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

#define DEBUG_TYPE "test-loop-fusion"

using namespace mlir;

// Command-line category under which the three test flags below are listed.
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

// When set, runOnOperation() runs the dependence-check test on every pair of
// gathered loop nests.
static llvm::cl::opt<bool> clTestDependenceCheck(
    "test-loop-fusion-dependence-check",
    llvm::cl::desc("Enable testing of loop fusion dependence check"),
    llvm::cl::cat(clOptionsCategory));

// When set, runOnOperation() runs the slice-computation test on every pair of
// gathered loop nests.
static llvm::cl::opt<bool> clTestSliceComputation(
    "test-loop-fusion-slice-computation",
    llvm::cl::desc("Enable testing of loop fusion slice computation"),
    llvm::cl::cat(clOptionsCategory));

// When set, runOnOperation() runs only the fusion transformation test
// (repeated until nothing changes) and then returns, so the two flags above
// are not consulted.
static llvm::cl::opt<bool> clTestLoopFusionTransformation(
    "test-loop-fusion-transformation",
    llvm::cl::desc("Enable testing of loop fusion transformation"),
    llvm::cl::cat(clOptionsCategory));

namespace {

// Test-only pass anchored on func::FuncOp. Which loop fusion utility gets
// exercised is selected by the command-line flags declared above.
struct TestLoopFusion
    : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)

  // Command-line argument that selects this pass.
  StringRef getArgument() const final { return "test-loop-fusion"; }
  // One-line description of this pass.
  StringRef getDescription() const final {
    return "Tests loop fusion utility functions.";
  }
  // Defined out of line, after the test helpers it dispatches to.
  void runOnOperation() override;
};
53 54 } // namespace 55 56 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths 57 // in range ['loopDepth' + 1, 'maxLoopDepth']. 58 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. 59 // Returns false as IR is not transformed. 60 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp, 61 unsigned i, unsigned j, unsigned loopDepth, 62 unsigned maxLoopDepth) { 63 mlir::ComputationSliceState sliceUnion; 64 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 65 FusionResult result = 66 mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion); 67 if (result.value == FusionResult::FailBlockDependence) { 68 srcForOp->emitRemark("block-level dependence preventing" 69 " fusion of loop nest ") 70 << i << " into loop nest " << j << " at depth " << loopDepth; 71 } 72 } 73 return false; 74 } 75 76 // Returns the index of 'op' in its block. 77 static unsigned getBlockIndex(Operation &op) { 78 unsigned index = 0; 79 for (auto &opX : *op.getBlock()) { 80 if (&op == &opX) 81 break; 82 ++index; 83 } 84 return index; 85 } 86 87 // Returns a string representation of 'sliceUnion'. 
88 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) { 89 std::string result; 90 llvm::raw_string_ostream os(result); 91 // Slice insertion point format [loop-depth, operation-block-index] 92 unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint); 93 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint); 94 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb) 95 << ")"; 96 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size()); 97 os << " loop bounds: "; 98 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) { 99 os << '['; 100 sliceUnion.lbs[k].print(os); 101 os << ", "; 102 sliceUnion.ubs[k].print(os); 103 os << "] "; 104 } 105 return os.str(); 106 } 107 108 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths 109 /// in range ['loopDepth' + 1, 'maxLoopDepth']. 110 /// Emits a string representation of the slice union as a remark on 'loops[j]' 111 /// and marks this as incorrect slice if the slice is invalid. Returns false as 112 /// IR is not transformed. 113 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB, 114 unsigned i, unsigned j, unsigned loopDepth, 115 unsigned maxLoopDepth) { 116 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 117 mlir::ComputationSliceState sliceUnion; 118 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 119 if (result.value == FusionResult::Success) { 120 forOpB->emitRemark("slice (") 121 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 122 << " : " << getSliceStr(sliceUnion) << ")"; 123 } else if (result.value == FusionResult::FailIncorrectSlice) { 124 forOpB->emitRemark("Incorrect slice (") 125 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 126 << " : " << getSliceStr(sliceUnion) << ")"; 127 } 128 } 129 return false; 130 } 131 132 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range 133 // ['loopDepth' + 1, 'maxLoopDepth']. 
134 // Returns true if loops were successfully fused, false otherwise. 135 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB, 136 unsigned i, unsigned j, 137 unsigned loopDepth, 138 unsigned maxLoopDepth) { 139 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 140 mlir::ComputationSliceState sliceUnion; 141 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 142 if (result.value == FusionResult::Success) { 143 mlir::fuseLoops(forOpA, forOpB, sliceUnion); 144 // Note: 'forOpA' is removed to simplify test output. A proper loop 145 // fusion pass should check the data dependence graph and run memref 146 // region analysis to ensure removing 'forOpA' is safe. 147 forOpA.erase(); 148 return true; 149 } 150 } 151 return false; 152 } 153 154 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned, 155 unsigned, unsigned)>; 156 157 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 158 // If 'return_on_change' is true, returns on first invocation of 'fn' which 159 // returns true. 160 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops, 161 LoopFunc fn, bool returnOnChange = false) { 162 bool changed = false; 163 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end; 164 ++loopDepth) { 165 auto &loops = depthToLoops[loopDepth]; 166 unsigned numLoops = loops.size(); 167 for (unsigned j = 0; j < numLoops; ++j) { 168 for (unsigned k = 0; k < numLoops; ++k) { 169 if (j != k) 170 changed |= 171 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size()); 172 if (changed && returnOnChange) 173 return true; 174 } 175 } 176 } 177 return changed; 178 } 179 180 void TestLoopFusion::runOnOperation() { 181 std::vector<SmallVector<AffineForOp, 2>> depthToLoops; 182 if (clTestLoopFusionTransformation) { 183 // Run loop fusion until a fixed point is reached. 184 do { 185 depthToLoops.clear(); 186 // Gather all AffineForOps by loop depth. 
187 gatherLoops(getOperation(), depthToLoops); 188 189 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'. 190 } while (iterateLoops(depthToLoops, testLoopFusionTransformation, 191 /*returnOnChange=*/true)); 192 return; 193 } 194 195 // Gather all AffineForOps by loop depth. 196 gatherLoops(getOperation(), depthToLoops); 197 198 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 199 if (clTestDependenceCheck) 200 iterateLoops(depthToLoops, testDependenceCheck); 201 if (clTestSliceComputation) 202 iterateLoops(depthToLoops, testSliceComputation); 203 } 204 205 namespace mlir { 206 namespace test { 207 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); } 208 } // namespace test 209 } // namespace mlir 210