1 //===------ Simplify.cpp ----------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Simplify a SCoP by removing unnecessary statements and accesses.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/Simplify.h"
15 #include "polly/ScopInfo.h"
16 #include "polly/ScopPass.h"
17 #include "polly/Support/GICHelper.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/Support/Debug.h"
20 #define DEBUG_TYPE "polly-simplify"
21 
22 using namespace llvm;
23 using namespace polly;
24 
25 namespace {
26 
27 STATISTIC(ScopsProcessed, "Number of SCoPs processed");
28 STATISTIC(ScopsModified, "Number of SCoPs simplified");
29 
30 STATISTIC(PairUnequalAccRels, "Number of Load-Store pairs NOT removed because "
31                               "of different access relations");
32 STATISTIC(InBetweenStore, "Number of Load-Store pairs NOT removed because "
33                           "there is another store between them");
34 STATISTIC(TotalRedundantWritesRemoved,
35           "Number of writes of same value removed in any SCoP");
36 STATISTIC(TotalStmtsRemoved, "Number of statements removed in any SCoP");
37 
38 class Simplify : public ScopPass {
39 private:
40   /// The last/current SCoP that is/has been processed.
41   Scop *S;
42 
43   /// Number of redundant writes removed from this SCoP.
44   int RedundantWritesRemoved = 0;
45 
46   /// Number of unnecessary statements removed from the SCoP.
47   int StmtsRemoved = 0;
48 
49   /// Return whether at least one simplification has been applied.
50   bool isModified() const {
51     return RedundantWritesRemoved > 0 || StmtsRemoved > 0;
52   }
53 
54   MemoryAccess *getReadAccessForValue(ScopStmt *Stmt, llvm::Value *Val) {
55     if (!isa<Instruction>(Val))
56       return nullptr;
57 
58     for (auto *MA : *Stmt) {
59       if (!MA->isRead())
60         continue;
61       if (MA->getAccessValue() != Val)
62         continue;
63 
64       return MA;
65     }
66 
67     return nullptr;
68   }
69 
70   /// Return a write access that occurs between @p From and @p To.
71   ///
72   /// In region statements the order is ignored because we cannot predict it.
73   ///
74   /// @param Stmt    Statement of both writes.
75   /// @param From    Start looking after this access.
76   /// @param To      Stop looking at this access, with the access itself.
77   /// @param Targets Look for an access that may wrote to one of these elements.
78   ///
79   /// @return A write access between @p From and @p To that writes to at least
80   ///         one element in @p Targets.
81   MemoryAccess *hasWriteBetween(ScopStmt *Stmt, MemoryAccess *From,
82                                 MemoryAccess *To, isl::map Targets) {
83     auto TargetsSpace = give(isl_map_get_space(Targets.keep()));
84 
85     bool Started = Stmt->isRegionStmt();
86     for (auto *Acc : *Stmt) {
87       if (Acc->isLatestScalarKind())
88         continue;
89 
90       if (Stmt->isBlockStmt() && From == Acc) {
91         assert(!Started);
92         Started = true;
93         continue;
94       }
95       if (Stmt->isBlockStmt() && To == Acc) {
96         assert(Started);
97         return nullptr;
98       }
99       if (!Started)
100         continue;
101 
102       if (!Acc->isWrite())
103         continue;
104 
105       auto AccRel = give(Acc->getAccessRelation());
106       auto AccRelSpace = give(isl_map_get_space(AccRel.keep()));
107 
108       // Spaces being different means that they access different arrays.
109       if (isl_space_has_equal_tuples(TargetsSpace.keep(), AccRelSpace.keep()) ==
110           isl_bool_false)
111         continue;
112 
113       AccRel = give(isl_map_intersect_domain(AccRel.take(),
114                                              Acc->getStatement()->getDomain()));
115       AccRel = give(isl_map_intersect_params(AccRel.take(), S->getContext()));
116       auto CommonElt = give(isl_map_intersect(Targets.copy(), AccRel.copy()));
117       if (isl_map_is_empty(CommonElt.keep()) != isl_bool_true)
118         return Acc;
119     }
120     assert(Stmt->isRegionStmt() &&
121            "To must be encountered in block statements");
122     return nullptr;
123   }
124 
125   /// Remove writes that just write the same value already stored in the
126   /// element.
127   void removeRedundantWrites() {
128     // Delay actual removal to not invalidate iterators.
129     SmallVector<MemoryAccess *, 8> StoresToRemove;
130 
131     for (auto &Stmt : *S) {
132       for (auto *WA : Stmt) {
133         if (!WA->isMustWrite())
134           continue;
135         if (!WA->isLatestArrayKind())
136           continue;
137         if (!isa<StoreInst>(WA->getAccessInstruction()))
138           continue;
139 
140         auto ReadingValue = WA->getAccessValue();
141         if (!ReadingValue)
142           continue;
143 
144         auto RA = getReadAccessForValue(&Stmt, ReadingValue);
145         if (!RA)
146           continue;
147         if (!RA->isLatestArrayKind())
148           continue;
149 
150         auto WARel = give(WA->getLatestAccessRelation());
151         WARel = give(isl_map_intersect_domain(WARel.take(),
152                                               WA->getStatement()->getDomain()));
153         WARel = give(isl_map_intersect_params(WARel.take(), S->getContext()));
154         auto RARel = give(RA->getLatestAccessRelation());
155         RARel = give(isl_map_intersect_domain(RARel.take(),
156                                               RA->getStatement()->getDomain()));
157         RARel = give(isl_map_intersect_params(RARel.take(), S->getContext()));
158 
159         if (isl_map_is_equal(RARel.keep(), WARel.keep()) != isl_bool_true) {
160           PairUnequalAccRels++;
161           DEBUG(dbgs() << "Not cleaning up " << WA
162                        << " because of unequal access relations:\n");
163           DEBUG(dbgs() << "      RA: " << RARel << "\n");
164           DEBUG(dbgs() << "      WA: " << WARel << "\n");
165           continue;
166         }
167 
168         if (auto *Conflicting = hasWriteBetween(&Stmt, RA, WA, WARel)) {
169           InBetweenStore++;
170           DEBUG(dbgs() << "Not cleaning up " << WA
171                        << " because there is another store to the same element "
172                           "between\n");
173           DEBUG(Conflicting->print(dbgs()));
174           continue;
175         }
176 
177         StoresToRemove.push_back(WA);
178       }
179     }
180 
181     for (auto *WA : StoresToRemove) {
182       auto Stmt = WA->getStatement();
183       auto AccRel = give(WA->getAccessRelation());
184       auto AccVal = WA->getAccessValue();
185 
186       DEBUG(dbgs() << "Cleanup of " << WA << ":\n");
187       DEBUG(dbgs() << "      Scalar: " << *AccVal << "\n");
188       DEBUG(dbgs() << "      AccRel: " << AccRel << "\n");
189 
190       Stmt->removeSingleMemoryAccess(WA);
191 
192       RedundantWritesRemoved++;
193       TotalRedundantWritesRemoved++;
194     }
195   }
196 
197   /// Remove statements without side effects.
198   void removeUnnecessayStmts() {
199     auto NumStmtsBefore = S->getSize();
200     S->simplifySCoP(true);
201     assert(NumStmtsBefore >= S->getSize());
202     StmtsRemoved = NumStmtsBefore - S->getSize();
203     DEBUG(dbgs() << "Removed " << StmtsRemoved << " (of " << NumStmtsBefore
204                  << ") statements\n");
205     TotalStmtsRemoved += StmtsRemoved;
206   }
207 
208   /// Print simplification statistics to @p OS.
209   void printStatistics(llvm::raw_ostream &OS, int Indent = 0) const {
210     OS.indent(Indent) << "Statistics {\n";
211     OS.indent(Indent + 4) << "Redundant writes removed: "
212                           << RedundantWritesRemoved << "\n";
213     OS.indent(Indent + 4) << "Stmts removed: " << StmtsRemoved << "\n";
214     OS.indent(Indent) << "}\n";
215   }
216 
217   /// Print the current state of all MemoryAccesses to @p OS.
218   void printAccesses(llvm::raw_ostream &OS, int Indent = 0) const {
219     OS.indent(Indent) << "After accesses {\n";
220     for (auto &Stmt : *S) {
221       OS.indent(Indent + 4) << Stmt.getBaseName() << "\n";
222       for (auto *MA : Stmt)
223         MA->print(OS);
224     }
225     OS.indent(Indent) << "}\n";
226   }
227 
228 public:
229   static char ID;
230   explicit Simplify() : ScopPass(ID) {}
231 
232   virtual void getAnalysisUsage(AnalysisUsage &AU) const override {
233     AU.addRequiredTransitive<ScopInfoRegionPass>();
234     AU.setPreservesAll();
235   }
236 
237   virtual bool runOnScop(Scop &S) override {
238     // Reset statistics of last processed SCoP.
239     releaseMemory();
240 
241     // Prepare processing of this SCoP.
242     this->S = &S;
243     ScopsProcessed++;
244 
245     DEBUG(dbgs() << "Removing redundant writes...\n");
246     removeRedundantWrites();
247 
248     DEBUG(dbgs() << "Removing statements without side effects...\n");
249     removeUnnecessayStmts();
250 
251     if (isModified())
252       ScopsModified++;
253     DEBUG(dbgs() << "\nFinal Scop:\n");
254     DEBUG(S.print(dbgs()));
255 
256     return false;
257   }
258 
259   virtual void printScop(raw_ostream &OS, Scop &S) const override {
260     assert(&S == this->S &&
261            "Can only print analysis for the last processed SCoP");
262     printStatistics(OS);
263 
264     if (!isModified()) {
265       OS << "SCoP could not be simplified\n";
266       return;
267     }
268     printAccesses(OS);
269   }
270 
271   virtual void releaseMemory() override {
272     S = nullptr;
273     StmtsRemoved = 0;
274   }
275 };
276 
277 char Simplify::ID;
278 } // anonymous namespace
279 
280 Pass *polly::createSimplifyPass() { return new Simplify(); }
281 
282 INITIALIZE_PASS_BEGIN(Simplify, "polly-simplify", "Polly - Simplify", false,
283                       false)
284 INITIALIZE_PASS_END(Simplify, "polly-simplify", "Polly - Simplify", false,
285                     false)
286