1 //===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to run pair-wise memref access dependence checks. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/AffineAnalysis.h" 14 #include "mlir/Analysis/AffineStructures.h" 15 #include "mlir/Analysis/Utils.h" 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/Pass/Pass.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "test-memref-dependence-check" 22 23 using namespace mlir; 24 25 namespace { 26 27 // TODO: Add common surrounding loop depth-wise dependence checks. 28 /// Checks dependences between all pairs of memref accesses in a Function. 29 struct TestMemRefDependenceCheck 30 : public PassWrapper<TestMemRefDependenceCheck, FunctionPass> { 31 SmallVector<Operation *, 4> loadsAndStores; 32 void runOnFunction() override; 33 }; 34 35 } // end anonymous namespace 36 37 // Returns a result string which represents the direction vector (if there was 38 // a dependence), returns the string "false" otherwise. 39 static std::string 40 getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, 41 ArrayRef<DependenceComponent> dependenceComponents) { 42 if (!ret) 43 return "false"; 44 if (dependenceComponents.empty() || loopNestDepth > numCommonLoops) 45 return "true"; 46 std::string result; 47 for (unsigned i = 0, e = dependenceComponents.size(); i < e; ++i) { 48 std::string lbStr = "-inf"; 49 if (dependenceComponents[i].lb.hasValue() && 50 dependenceComponents[i].lb.getValue() != 51 std::numeric_limits<int64_t>::min()) 52 lbStr = std::to_string(dependenceComponents[i].lb.getValue()); 53 54 std::string ubStr = "+inf"; 55 if (dependenceComponents[i].ub.hasValue() && 56 dependenceComponents[i].ub.getValue() != 57 std::numeric_limits<int64_t>::max()) 58 ubStr = std::to_string(dependenceComponents[i].ub.getValue()); 59 60 result += "[" + lbStr + ", " + ubStr + "]"; 61 } 62 return result; 63 } 64 65 // For each access in 'loadsAndStores', runs a dependence check between this 66 // "source" access and all subsequent "destination" accesses in 67 // 'loadsAndStores'. Emits the result of the dependence check as a note with 68 // the source access. 69 static void checkDependences(ArrayRef<Operation *> loadsAndStores) { 70 for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { 71 auto *srcOpInst = loadsAndStores[i]; 72 MemRefAccess srcAccess(srcOpInst); 73 for (unsigned j = 0; j < e; ++j) { 74 auto *dstOpInst = loadsAndStores[j]; 75 MemRefAccess dstAccess(dstOpInst); 76 77 unsigned numCommonLoops = 78 getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); 79 for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { 80 FlatAffineConstraints dependenceConstraints; 81 SmallVector<DependenceComponent, 2> dependenceComponents; 82 DependenceResult result = checkMemrefAccessDependence( 83 srcAccess, dstAccess, d, &dependenceConstraints, 84 &dependenceComponents); 85 assert(result.value != DependenceResult::Failure); 86 bool ret = hasDependence(result); 87 // TODO: Print dependence type (i.e. RAW, etc) and print 88 // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance 89 // vectors from ([1, 1], [3, 3]) to (1, 3). 90 srcOpInst->emitRemark("dependence from ") 91 << i << " to " << j << " at depth " << d << " = " 92 << getDirectionVectorStr(ret, numCommonLoops, d, 93 dependenceComponents); 94 } 95 } 96 } 97 } 98 99 // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. 100 // Runs pair-wise dependence checks. 101 void TestMemRefDependenceCheck::runOnFunction() { 102 // Collect the loads and stores within the function. 103 loadsAndStores.clear(); 104 getFunction().walk([&](Operation *op) { 105 if (isa<AffineLoadOp, AffineStoreOp>(op)) 106 loadsAndStores.push_back(op); 107 }); 108 109 checkDependences(loadsAndStores); 110 } 111 112 namespace mlir { 113 namespace test { 114 void registerTestMemRefDependenceCheck() { 115 PassRegistration<TestMemRefDependenceCheck> pass( 116 "test-memref-dependence-check", 117 "Checks dependences between all pairs of memref accesses."); 118 } 119 } // namespace test 120 } // namespace mlir 121