1 //===------ Simplify.cpp ----------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Simplify a SCoP by removing unnecessary statements and accesses.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/Simplify.h"
15 #include "polly/ScopInfo.h"
16 #include "polly/ScopPass.h"
17 #include "polly/Support/GICHelper.h"
18 #include "polly/Support/ISLOStream.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/Support/Debug.h"
21 #define DEBUG_TYPE "polly-simplify"
22 
23 using namespace llvm;
24 using namespace polly;
25 
26 namespace {
27 
// LLVM STATISTIC counters (see llvm/ADT/Statistic.h). Unlike the per-SCoP
// counters stored in the Simplify pass object (which releaseMemory() resets),
// these are never reset here and accumulate over every SCoP the pass processes.
STATISTIC(ScopsProcessed, "Number of SCoPs processed");
STATISTIC(ScopsModified, "Number of SCoPs simplified");

// Reasons why removeRedundantWrites() kept a Load-Store pair.
STATISTIC(PairUnequalAccRels, "Number of Load-Store pairs NOT removed because "
                              "of different access relations");
STATISTIC(InBetweenStore, "Number of Load-Store pairs NOT removed because "
                          "there is another store between them");
// Running totals of the individual simplifications.
STATISTIC(TotalOverwritesRemoved, "Number of removed overwritten writes");
STATISTIC(TotalRedundantWritesRemoved,
          "Number of writes of same value removed in any SCoP");
STATISTIC(TotalStmtsRemoved, "Number of statements removed in any SCoP");
39 
40 static bool isImplicitRead(MemoryAccess *MA) {
41   return MA->isRead() && MA->isOriginalScalarKind();
42 }
43 
44 static bool isExplicitAccess(MemoryAccess *MA) {
45   return MA->isOriginalArrayKind();
46 }
47 
48 static bool isImplicitWrite(MemoryAccess *MA) {
49   return MA->isWrite() && MA->isOriginalScalarKind();
50 }
51 
52 /// Return a vector that contains MemoryAccesses in the order in
53 /// which they are executed.
54 ///
55 /// The order is:
56 /// - Implicit reads (BlockGenerator::generateScalarLoads)
57 /// - Explicit reads and writes (BlockGenerator::generateArrayLoad,
58 ///   BlockGenerator::generateArrayStore)
59 ///   - In block statements, the accesses are in order in which their
60 ///     instructions are executed.
61 ///   - In region statements, that order of execution is not predictable at
62 ///     compile-time.
63 /// - Implicit writes (BlockGenerator::generateScalarStores)
64 ///   The order in which implicit writes are executed relative to each other is
65 ///   undefined.
66 static SmallVector<MemoryAccess *, 32> getAccessesInOrder(ScopStmt &Stmt) {
67 
68   SmallVector<MemoryAccess *, 32> Accesses;
69 
70   for (MemoryAccess *MemAcc : Stmt)
71     if (isImplicitRead(MemAcc))
72       Accesses.push_back(MemAcc);
73 
74   for (MemoryAccess *MemAcc : Stmt)
75     if (isExplicitAccess(MemAcc))
76       Accesses.push_back(MemAcc);
77 
78   for (MemoryAccess *MemAcc : Stmt)
79     if (isImplicitWrite(MemAcc))
80       Accesses.push_back(MemAcc);
81 
82   return Accesses;
83 }
84 
85 class Simplify : public ScopPass {
86 private:
87   /// The last/current SCoP that is/has been processed.
88   Scop *S;
89 
90   /// Number of writes that are overwritten anyway.
91   int OverwritesRemoved = 0;
92 
93   /// Number of redundant writes removed from this SCoP.
94   int RedundantWritesRemoved = 0;
95 
96   /// Number of unnecessary statements removed from the SCoP.
97   int StmtsRemoved = 0;
98 
99   /// Return whether at least one simplification has been applied.
100   bool isModified() const {
101     return OverwritesRemoved > 0 || RedundantWritesRemoved > 0 ||
102            StmtsRemoved > 0;
103   }
104 
105   MemoryAccess *getReadAccessForValue(ScopStmt *Stmt, llvm::Value *Val) {
106     if (!isa<Instruction>(Val))
107       return nullptr;
108 
109     for (auto *MA : *Stmt) {
110       if (!MA->isRead())
111         continue;
112       if (MA->getAccessValue() != Val)
113         continue;
114 
115       return MA;
116     }
117 
118     return nullptr;
119   }
120 
121   /// Return a write access that occurs between @p From and @p To.
122   ///
123   /// In region statements the order is ignored because we cannot predict it.
124   ///
125   /// @param Stmt    Statement of both writes.
126   /// @param From    Start looking after this access.
127   /// @param To      Stop looking at this access, with the access itself.
128   /// @param Targets Look for an access that may wrote to one of these elements.
129   ///
130   /// @return A write access between @p From and @p To that writes to at least
131   ///         one element in @p Targets.
132   MemoryAccess *hasWriteBetween(ScopStmt *Stmt, MemoryAccess *From,
133                                 MemoryAccess *To, isl::map Targets) {
134     auto TargetsSpace = Targets.get_space();
135 
136     bool Started = Stmt->isRegionStmt();
137     auto Accesses = getAccessesInOrder(*Stmt);
138     for (auto *Acc : Accesses) {
139       if (Acc->isLatestScalarKind())
140         continue;
141 
142       if (Stmt->isBlockStmt() && From == Acc) {
143         assert(!Started);
144         Started = true;
145         continue;
146       }
147       if (Stmt->isBlockStmt() && To == Acc) {
148         assert(Started);
149         return nullptr;
150       }
151       if (!Started)
152         continue;
153 
154       if (!Acc->isWrite())
155         continue;
156 
157       auto AccRel = give(Acc->getAccessRelation());
158       auto AccRelSpace = AccRel.get_space();
159 
160       // Spaces being different means that they access different arrays.
161       if (!TargetsSpace.has_equal_tuples(AccRelSpace))
162         continue;
163 
164       AccRel = AccRel.intersect_domain(give(Acc->getStatement()->getDomain()));
165       AccRel = AccRel.intersect_params(give(S->getContext()));
166       auto CommonElt = Targets.intersect(AccRel);
167       if (!CommonElt.is_empty())
168         return Acc;
169     }
170     assert(Stmt->isRegionStmt() &&
171            "To must be encountered in block statements");
172     return nullptr;
173   }
174 
175   /// Remove writes that are overwritten unconditionally later in the same
176   /// statement.
177   ///
178   /// There must be no read of the same value between the write (that is to be
179   /// removed) and the overwrite.
180   void removeOverwrites() {
181     for (auto &Stmt : *S) {
182       auto Domain = give(Stmt.getDomain());
183       isl::union_map WillBeOverwritten =
184           isl::union_map::empty(give(S->getParamSpace()));
185 
186       SmallVector<MemoryAccess *, 32> Accesses(getAccessesInOrder(Stmt));
187 
188       // Iterate in reverse order, so the overwrite comes before the write that
189       // is to be removed.
190       for (auto *MA : reverse(Accesses)) {
191 
192         // In region statements, the explicit accesses can be in blocks that are
193         // can be executed in any order. We therefore process only the implicit
194         // writes and stop after that.
195         if (Stmt.isRegionStmt() && isExplicitAccess(MA))
196           break;
197 
198         auto AccRel = give(MA->getAccessRelation());
199         AccRel = AccRel.intersect_domain(Domain);
200         AccRel = AccRel.intersect_params(give(S->getContext()));
201 
202         // If a value is read in-between, do not consider it as overwritten.
203         if (MA->isRead()) {
204           WillBeOverwritten = WillBeOverwritten.subtract(AccRel);
205           continue;
206         }
207 
208         // If all of a write's elements are overwritten, remove it.
209         isl::union_map AccRelUnion = AccRel;
210         if (AccRelUnion.is_subset(WillBeOverwritten)) {
211           DEBUG(dbgs() << "Removing " << MA
212                        << " which will be overwritten anyway\n");
213 
214           Stmt.removeSingleMemoryAccess(MA);
215           OverwritesRemoved++;
216           TotalOverwritesRemoved++;
217         }
218 
219         // Unconditional writes overwrite other values.
220         if (MA->isMustWrite())
221           WillBeOverwritten = WillBeOverwritten.add_map(AccRel);
222       }
223     }
224   }
225 
226   /// Remove writes that just write the same value already stored in the
227   /// element.
228   void removeRedundantWrites() {
229     // Delay actual removal to not invalidate iterators.
230     SmallVector<MemoryAccess *, 8> StoresToRemove;
231 
232     for (auto &Stmt : *S) {
233       for (auto *WA : Stmt) {
234         if (!WA->isMustWrite())
235           continue;
236         if (!WA->isLatestArrayKind())
237           continue;
238         if (!isa<StoreInst>(WA->getAccessInstruction()))
239           continue;
240 
241         auto ReadingValue = WA->getAccessValue();
242         if (!ReadingValue)
243           continue;
244 
245         auto RA = getReadAccessForValue(&Stmt, ReadingValue);
246         if (!RA)
247           continue;
248         if (!RA->isLatestArrayKind())
249           continue;
250 
251         auto WARel = give(WA->getLatestAccessRelation());
252         WARel = WARel.intersect_domain(give(WA->getStatement()->getDomain()));
253         WARel = WARel.intersect_params(give(S->getContext()));
254         auto RARel = give(RA->getLatestAccessRelation());
255         RARel = RARel.intersect_domain(give(RA->getStatement()->getDomain()));
256         RARel = RARel.intersect_params(give(S->getContext()));
257 
258         if (!RARel.is_equal(WARel)) {
259           PairUnequalAccRels++;
260           DEBUG(dbgs() << "Not cleaning up " << WA
261                        << " because of unequal access relations:\n");
262           DEBUG(dbgs() << "      RA: " << RARel << "\n");
263           DEBUG(dbgs() << "      WA: " << WARel << "\n");
264           continue;
265         }
266 
267         if (auto *Conflicting = hasWriteBetween(&Stmt, RA, WA, WARel)) {
268           (void)Conflicting;
269           InBetweenStore++;
270           DEBUG(dbgs() << "Not cleaning up " << WA
271                        << " because there is another store to the same element "
272                           "between\n");
273           DEBUG(Conflicting->print(dbgs()));
274           continue;
275         }
276 
277         StoresToRemove.push_back(WA);
278       }
279     }
280 
281     for (auto *WA : StoresToRemove) {
282       auto Stmt = WA->getStatement();
283       auto AccRel = give(WA->getAccessRelation());
284       auto AccVal = WA->getAccessValue();
285 
286       DEBUG(dbgs() << "Cleanup of " << WA << ":\n");
287       DEBUG(dbgs() << "      Scalar: " << *AccVal << "\n");
288       DEBUG(dbgs() << "      AccRel: " << AccRel << "\n");
289       (void)AccVal;
290       (void)AccRel;
291 
292       Stmt->removeSingleMemoryAccess(WA);
293 
294       RedundantWritesRemoved++;
295       TotalRedundantWritesRemoved++;
296     }
297   }
298 
299   /// Remove statements without side effects.
300   void removeUnnecessayStmts() {
301     auto NumStmtsBefore = S->getSize();
302     S->simplifySCoP(true);
303     assert(NumStmtsBefore >= S->getSize());
304     StmtsRemoved = NumStmtsBefore - S->getSize();
305     DEBUG(dbgs() << "Removed " << StmtsRemoved << " (of " << NumStmtsBefore
306                  << ") statements\n");
307     TotalStmtsRemoved += StmtsRemoved;
308   }
309 
310   /// Print simplification statistics to @p OS.
311   void printStatistics(llvm::raw_ostream &OS, int Indent = 0) const {
312     OS.indent(Indent) << "Statistics {\n";
313     OS.indent(Indent + 4) << "Overwrites removed: " << OverwritesRemoved
314                           << '\n';
315     OS.indent(Indent + 4) << "Redundant writes removed: "
316                           << RedundantWritesRemoved << "\n";
317     OS.indent(Indent + 4) << "Stmts removed: " << StmtsRemoved << "\n";
318     OS.indent(Indent) << "}\n";
319   }
320 
321   /// Print the current state of all MemoryAccesses to @p OS.
322   void printAccesses(llvm::raw_ostream &OS, int Indent = 0) const {
323     OS.indent(Indent) << "After accesses {\n";
324     for (auto &Stmt : *S) {
325       OS.indent(Indent + 4) << Stmt.getBaseName() << "\n";
326       for (auto *MA : Stmt)
327         MA->print(OS);
328     }
329     OS.indent(Indent) << "}\n";
330   }
331 
332 public:
333   static char ID;
334   explicit Simplify() : ScopPass(ID) {}
335 
336   virtual void getAnalysisUsage(AnalysisUsage &AU) const override {
337     AU.addRequiredTransitive<ScopInfoRegionPass>();
338     AU.setPreservesAll();
339   }
340 
341   virtual bool runOnScop(Scop &S) override {
342     // Reset statistics of last processed SCoP.
343     releaseMemory();
344 
345     // Prepare processing of this SCoP.
346     this->S = &S;
347     ScopsProcessed++;
348 
349     DEBUG(dbgs() << "Removing overwrites...\n");
350     removeOverwrites();
351 
352     DEBUG(dbgs() << "Removing redundant writes...\n");
353     removeRedundantWrites();
354 
355     DEBUG(dbgs() << "Removing statements without side effects...\n");
356     removeUnnecessayStmts();
357 
358     if (isModified())
359       ScopsModified++;
360     DEBUG(dbgs() << "\nFinal Scop:\n");
361     DEBUG(S.print(dbgs()));
362 
363     return false;
364   }
365 
366   virtual void printScop(raw_ostream &OS, Scop &S) const override {
367     assert(&S == this->S &&
368            "Can only print analysis for the last processed SCoP");
369     printStatistics(OS);
370 
371     if (!isModified()) {
372       OS << "SCoP could not be simplified\n";
373       return;
374     }
375     printAccesses(OS);
376   }
377 
378   virtual void releaseMemory() override {
379     S = nullptr;
380 
381     OverwritesRemoved = 0;
382     RedundantWritesRemoved = 0;
383     StmtsRemoved = 0;
384   }
385 };
386 
387 char Simplify::ID;
388 } // anonymous namespace
389 
390 Pass *polly::createSimplifyPass() { return new Simplify(); }
391 
// Register the Simplify pass under the command-line name "polly-simplify".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of the legacy INITIALIZE_PASS macros; confirm
// against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(Simplify, "polly-simplify", "Polly - Simplify", false,
                      false)
INITIALIZE_PASS_END(Simplify, "polly-simplify", "Polly - Simplify", false,
                    false)
396