1 //===- ASTDiff.cpp - AST differencing implementation-----------*- C++ -*- -===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// This file contains definitions for the AST differencing interface.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/Tooling/ASTDiff/ASTDiff.h"
15 
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/Lex/Lexer.h"
18 #include "llvm/ADT/PriorityQueue.h"
19 
#include <algorithm>
#include <limits>
#include <memory>
#include <unordered_set>
23 
24 using namespace llvm;
25 using namespace clang;
26 
27 namespace clang {
28 namespace diff {
29 
30 namespace {
31 /// Maps nodes of the left tree to ones on the right, and vice versa.
32 class Mapping {
33 public:
34   Mapping() = default;
35   Mapping(Mapping &&Other) = default;
36   Mapping &operator=(Mapping &&Other) = default;
37 
38   Mapping(size_t Size) {
39     SrcToDst = llvm::make_unique<NodeId[]>(Size);
40     DstToSrc = llvm::make_unique<NodeId[]>(Size);
41   }
42 
43   void link(NodeId Src, NodeId Dst) {
44     SrcToDst[Src] = Dst, DstToSrc[Dst] = Src;
45   }
46 
47   NodeId getDst(NodeId Src) const { return SrcToDst[Src]; }
48   NodeId getSrc(NodeId Dst) const { return DstToSrc[Dst]; }
49   bool hasSrc(NodeId Src) const { return getDst(Src).isValid(); }
50   bool hasDst(NodeId Dst) const { return getSrc(Dst).isValid(); }
51 
52 private:
53   std::unique_ptr<NodeId[]> SrcToDst, DstToSrc;
54 };
55 } // end anonymous namespace
56 
/// Computes and stores the node mapping between two syntax trees.
///
/// Constructing an Impl runs the whole diff: computeMapping() followed by
/// computeChangeKinds(), which writes Change/Shift into the nodes of both
/// trees.
class ASTDiff::Impl {
public:
  // The trees being compared. Node ids of T1 are the source (Src) side of
  // TheMapping, node ids of T2 the destination (Dst) side.
  SyntaxTree::Impl &T1, &T2;
  // Result of computeMapping().
  Mapping TheMapping;

  Impl(SyntaxTree::Impl &T1, SyntaxTree::Impl &T2,
       const ComparisonOptions &Options);

  /// Matches nodes one-by-one based on their similarity.
  /// Runs matchTopDown() and then extends that result with matchBottomUp().
  void computeMapping();

  // Compute Change for each node based on similarity.
  // Nodes without a partner in M become Delete (T1) or Insert (T2); each
  // node's Shift is adjusted as part of move detection.
  void computeChangeKinds(Mapping &M);

  // Returns the node mapped to Id in the other tree. Tree must be the
  // implementation object of either T1 or T2 (asserted otherwise).
  NodeId getMapped(const std::unique_ptr<SyntaxTree::Impl> &Tree,
                   NodeId Id) const {
    if (&*Tree == &T1)
      return TheMapping.getDst(Id);
    assert(&*Tree == &T2 && "Invalid tree.");
    return TheMapping.getSrc(Id);
  }

private:
  // Returns true if the two subtrees are identical.
  // (Same number of children, matching allowed, equal node values, and the
  // same holds recursively for corresponding children.)
  bool identical(NodeId Id1, NodeId Id2) const;

  // Returns false if the nodes must not be matched.
  // Delegates to ComparisonOptions::isMatchingAllowed.
  bool isMatchingPossible(NodeId Id1, NodeId Id2) const;

  // Returns true if the nodes' parents are matched.
  // Two nodes that both lack a parent also count as having the same parent.
  bool haveSameParents(const Mapping &M, NodeId Id1, NodeId Id2) const;

  // Uses an optimal albeit slow algorithm to compute a mapping between two
  // subtrees, but only if both have fewer nodes than MaxSize.
  // Pairs whose source or destination is already linked in M are skipped.
  void addOptimalMapping(Mapping &M, NodeId Id1, NodeId Id2) const;

  // Computes the ratio of common descendants between the two nodes.
  // Descendants are only considered to be equal when they are mapped in M.
  double getSimilarity(const Mapping &M, NodeId Id1, NodeId Id2) const;

  // Returns the node that has the highest degree of similarity.
  // Only unmapped T2 nodes that may be matched with Id1 and reach
  // Options.MinSimilarity are considered; returns a default NodeId if none.
  NodeId findCandidate(const Mapping &M, NodeId Id1) const;

  // Returns a mapping of identical subtrees.
  Mapping matchTopDown() const;

  // Tries to match any yet unmapped nodes, in a bottom-up fashion.
  void matchBottomUp(Mapping &M) const;

  // Options supplied by the caller; stored by reference, not copied.
  const ComparisonOptions &Options;

  // The optimal matcher needs access to isMatchingPossible().
  friend class ZhangShashaMatcher;
};
110 
/// Represents the AST of a TranslationUnit.
///
/// The tree is built in two passes (count, then fill in preorder). Node ids
/// are preorder indices into Nodes, so every subtree occupies the contiguous
/// id range [Root, Root.RightMostDescendant].
class SyntaxTree::Impl {
public:
  /// Constructs a tree from the entire translation unit.
  Impl(SyntaxTree *Parent, const ASTContext &AST);
  /// Constructs a tree from an AST node.
  Impl(SyntaxTree *Parent, Decl *N, const ASTContext &AST);
  Impl(SyntaxTree *Parent, Stmt *N, const ASTContext &AST);
  // NOTE(review): in the two templates below T appears only inside
  // std::enable_if<...>::type, a non-deduced context, and constructor
  // templates cannot be given explicit template arguments. They therefore
  // look unreachable; pointers to Stmt/Decl subclasses bind to the
  // non-template Stmt*/Decl* overloads above via derived-to-base conversion.
  template <class T>
  Impl(SyntaxTree *Parent,
       typename std::enable_if<std::is_base_of<Stmt, T>::value, T>::type *Node,
       const ASTContext &AST)
      : Impl(Parent, dyn_cast<Stmt>(Node), AST) {}
  template <class T>
  Impl(SyntaxTree *Parent,
       typename std::enable_if<std::is_base_of<Decl, T>::value, T>::type *Node,
       const ASTContext &AST)
      : Impl(Parent, dyn_cast<Decl>(Node), AST) {}

  // The public SyntaxTree that owns this implementation.
  SyntaxTree *Parent;
  const ASTContext &AST;
  // Leaf node ids, appended by the preorder walk as each leaf is finished.
  std::vector<NodeId> Leaves;
  // Maps preorder indices to postorder ones.
  std::vector<int> PostorderIds;
  // All node ids in breadth-first order (filled by initTree()).
  std::vector<NodeId> NodesBfs;

  int getSize() const { return Nodes.size(); }
  NodeId getRootId() const { return 0; }
  // Iteration over a tree visits ids 0 .. getSize()-1, i.e. preorder.
  PreorderIterator begin() const { return getRootId(); }
  PreorderIterator end() const { return getSize(); }

  const Node &getNode(NodeId Id) const { return Nodes[Id]; }
  Node &getMutableNode(NodeId Id) { return Nodes[Id]; }
  bool isValidNodeId(NodeId Id) const { return Id >= 0 && Id < getSize(); }
  void addNode(Node &N) { Nodes.push_back(N); }
  // Size of the subtree rooted at Id, including Id itself.
  int getNumberOfDescendants(NodeId Id) const;
  // True if Id lies in the subtree rooted at SubtreeRoot (inclusive).
  bool isInSubtree(NodeId Id, NodeId SubtreeRoot) const;
  // Index of Id among its parent's children (0 for the root). With Shifted,
  // the Shift values of that sibling and all earlier siblings are added.
  int findPositionInParent(NodeId Id, bool Shifted = false) const;

  std::string getNodeValue(NodeId Id) const;
  std::string getNodeValue(const Node &Node) const;

private:
  /// Nodes in preorder.
  std::vector<Node> Nodes;

  // Computes leftmost descendants, PostorderIds and NodesBfs.
  void initTree();
  void setLeftMostDescendants();
};
160 
161 template <class T>
162 static bool isNodeExcluded(const SourceManager &SrcMgr, T *N) {
163   if (!N)
164     return true;
165   SourceLocation SLoc = N->getLocStart();
166   return SLoc.isValid() && SrcMgr.isInSystemHeader(SLoc);
167 }
168 
169 namespace {
170 /// Counts the number of nodes that will be compared.
171 struct NodeCountVisitor : public RecursiveASTVisitor<NodeCountVisitor> {
172   int Count = 0;
173   const SyntaxTree::Impl &Tree;
174   NodeCountVisitor(const SyntaxTree::Impl &Tree) : Tree(Tree) {}
175   bool TraverseDecl(Decl *D) {
176     if (isNodeExcluded(Tree.AST.getSourceManager(), D))
177       return true;
178     ++Count;
179     RecursiveASTVisitor<NodeCountVisitor>::TraverseDecl(D);
180     return true;
181   }
182   bool TraverseStmt(Stmt *S) {
183     if (isNodeExcluded(Tree.AST.getSourceManager(), S))
184       return true;
185     ++Count;
186     RecursiveASTVisitor<NodeCountVisitor>::TraverseStmt(S);
187     return true;
188   }
189   bool TraverseType(QualType T) { return true; }
190 };
191 } // end anonymous namespace
192 
namespace {
// Sets Height, Parent and Children for each node.
// Second construction pass: walks the AST in preorder and, for every node
// that is not excluded, fills in Parent, Depth, ASTNode, Children,
// RightMostDescendant and Height, and appends leaves to Tree.Leaves. Node ids
// are assigned in visiting order, so they are preorder indices. Tree's node
// storage must already be sized (see NodeCountVisitor).
struct PreorderVisitor : public RecursiveASTVisitor<PreorderVisitor> {
  // Id: next id to assign. Depth: depth of the node being entered.
  int Id = 0, Depth = 0;
  // Id of the node currently being traversed; default-constructed for the
  // root (presumably the invalid NodeId, given the isValid() check below).
  NodeId Parent;
  SyntaxTree::Impl &Tree;

  PreorderVisitor(SyntaxTree::Impl &Tree) : Tree(Tree) {}

  // Enters ASTNode: assigns it the next id, links it under Parent, and makes
  // it the new Parent. Returns (own id, previous Parent) for PostTraverse.
  template <class T> std::tuple<NodeId, NodeId> PreTraverse(T *ASTNode) {
    NodeId MyId = Id;
    Node &N = Tree.getMutableNode(MyId);
    N.Parent = Parent;
    N.Depth = Depth;
    N.ASTNode = DynTypedNode::create(*ASTNode);
    assert(!N.ASTNode.getNodeKind().isNone() &&
           "Expected nodes to have a valid kind.");
    if (Parent.isValid()) {
      Node &P = Tree.getMutableNode(Parent);
      P.Children.push_back(MyId);
    }
    Parent = MyId;
    ++Id;
    ++Depth;
    return std::make_tuple(MyId, Tree.getNode(MyId).Parent);
  }
  // Leaves the node entered by the matching PreTraverse: restores Parent and
  // Depth, then computes RightMostDescendant and Height.
  void PostTraverse(std::tuple<NodeId, NodeId> State) {
    NodeId MyId, PreviousParent;
    std::tie(MyId, PreviousParent) = State;
    assert(MyId.isValid() && "Expecting to only traverse valid nodes.");
    Parent = PreviousParent;
    --Depth;
    Node &N = Tree.getMutableNode(MyId);
    // All descendants were numbered consecutively after MyId, so the last id
    // handed out is the rightmost descendant.
    N.RightMostDescendant = Id - 1;
    assert(N.RightMostDescendant >= 0 &&
           N.RightMostDescendant < Tree.getSize() &&
           "Rightmost descendant must be a valid tree node.");
    if (N.isLeaf())
      Tree.Leaves.push_back(MyId);
    // A leaf has height 1; an inner node is one taller than its tallest child.
    N.Height = 1;
    for (NodeId Child : N.Children)
      N.Height = std::max(N.Height, 1 + Tree.getNode(Child).Height);
  }
  bool TraverseDecl(Decl *D) {
    if (isNodeExcluded(Tree.AST.getSourceManager(), D))
      return true;
    auto SavedState = PreTraverse(D);
    RecursiveASTVisitor<PreorderVisitor>::TraverseDecl(D);
    PostTraverse(SavedState);
    return true;
  }
  bool TraverseStmt(Stmt *S) {
    if (isNodeExcluded(Tree.AST.getSourceManager(), S))
      return true;
    auto SavedState = PreTraverse(S);
    RecursiveASTVisitor<PreorderVisitor>::TraverseStmt(S);
    PostTraverse(SavedState);
    return true;
  }
  // Types are not part of the tree; must stay in sync with NodeCountVisitor.
  bool TraverseType(QualType T) { return true; }
};
} // end anonymous namespace
255 
256 SyntaxTree::Impl::Impl(SyntaxTree *Parent, const ASTContext &AST)
257     : Impl(Parent, AST.getTranslationUnitDecl(), AST) {}
258 
259 SyntaxTree::Impl::Impl(SyntaxTree *Parent, Decl *N, const ASTContext &AST)
260     : Parent(Parent), AST(AST) {
261   NodeCountVisitor NodeCounter(*this);
262   NodeCounter.TraverseDecl(N);
263   Nodes.resize(NodeCounter.Count);
264   PreorderVisitor PreorderWalker(*this);
265   PreorderWalker.TraverseDecl(N);
266   initTree();
267 }
268 
269 SyntaxTree::Impl::Impl(SyntaxTree *Parent, Stmt *N, const ASTContext &AST)
270     : Parent(Parent), AST(AST) {
271   NodeCountVisitor NodeCounter(*this);
272   NodeCounter.TraverseStmt(N);
273   Nodes.resize(NodeCounter.Count);
274   PreorderVisitor PreorderWalker(*this);
275   PreorderWalker.TraverseStmt(N);
276   initTree();
277 }
278 
279 static std::vector<NodeId> getSubtreePostorder(const SyntaxTree::Impl &Tree,
280                                                NodeId Root) {
281   std::vector<NodeId> Postorder;
282   std::function<void(NodeId)> Traverse = [&](NodeId Id) {
283     const Node &N = Tree.getNode(Id);
284     for (NodeId Child : N.Children)
285       Traverse(Child);
286     Postorder.push_back(Id);
287   };
288   Traverse(Root);
289   return Postorder;
290 }
291 
292 static std::vector<NodeId> getSubtreeBfs(const SyntaxTree::Impl &Tree,
293                                          NodeId Root) {
294   std::vector<NodeId> Ids;
295   size_t Expanded = 0;
296   Ids.push_back(Root);
297   while (Expanded < Ids.size())
298     for (NodeId Child : Tree.getNode(Ids[Expanded++]).Children)
299       Ids.push_back(Child);
300   return Ids;
301 }
302 
303 void SyntaxTree::Impl::initTree() {
304   setLeftMostDescendants();
305   int PostorderId = 0;
306   PostorderIds.resize(getSize());
307   std::function<void(NodeId)> PostorderTraverse = [&](NodeId Id) {
308     for (NodeId Child : getNode(Id).Children)
309       PostorderTraverse(Child);
310     PostorderIds[Id] = PostorderId;
311     ++PostorderId;
312   };
313   PostorderTraverse(getRootId());
314   NodesBfs = getSubtreeBfs(*this, getRootId());
315 }
316 
317 void SyntaxTree::Impl::setLeftMostDescendants() {
318   for (NodeId Leaf : Leaves) {
319     getMutableNode(Leaf).LeftMostDescendant = Leaf;
320     NodeId Parent, Cur = Leaf;
321     while ((Parent = getNode(Cur).Parent).isValid() &&
322            getNode(Parent).Children[0] == Cur) {
323       Cur = Parent;
324       getMutableNode(Cur).LeftMostDescendant = Leaf;
325     }
326   }
327 }
328 
329 int SyntaxTree::Impl::getNumberOfDescendants(NodeId Id) const {
330   return getNode(Id).RightMostDescendant - Id + 1;
331 }
332 
333 bool SyntaxTree::Impl::isInSubtree(NodeId Id, NodeId SubtreeRoot) const {
334   return Id >= SubtreeRoot && Id <= getNode(SubtreeRoot).RightMostDescendant;
335 }
336 
337 int SyntaxTree::Impl::findPositionInParent(NodeId Id, bool Shifted) const {
338   NodeId Parent = getNode(Id).Parent;
339   if (Parent.isInvalid())
340     return 0;
341   const auto &Siblings = getNode(Parent).Children;
342   int Position = 0;
343   for (size_t I = 0, E = Siblings.size(); I < E; ++I) {
344     if (Shifted)
345       Position += getNode(Siblings[I]).Shift;
346     if (Siblings[I] == Id) {
347       Position += I;
348       return Position;
349     }
350   }
351   llvm_unreachable("Node not found in parent's children.");
352 }
353 
354 std::string SyntaxTree::Impl::getNodeValue(NodeId Id) const {
355   return getNodeValue(getNode(Id));
356 }
357 
358 std::string SyntaxTree::Impl::getNodeValue(const Node &N) const {
359   const DynTypedNode &DTN = N.ASTNode;
360   if (auto *X = DTN.get<BinaryOperator>())
361     return X->getOpcodeStr();
362   if (auto *X = DTN.get<AccessSpecDecl>()) {
363     CharSourceRange Range(X->getSourceRange(), false);
364     return Lexer::getSourceText(Range, AST.getSourceManager(),
365                                 AST.getLangOpts());
366   }
367   if (auto *X = DTN.get<IntegerLiteral>()) {
368     SmallString<256> Str;
369     X->getValue().toString(Str, /*Radix=*/10, /*Signed=*/false);
370     return Str.str();
371   }
372   if (auto *X = DTN.get<StringLiteral>())
373     return X->getString();
374   if (auto *X = DTN.get<ValueDecl>())
375     return X->getNameAsString() + "(" + X->getType().getAsString() + ")";
376   if (DTN.get<DeclStmt>() || DTN.get<TranslationUnitDecl>())
377     return "";
378   std::string Value;
379   if (auto *X = DTN.get<DeclRefExpr>()) {
380     if (X->hasQualifier()) {
381       llvm::raw_string_ostream OS(Value);
382       PrintingPolicy PP(AST.getLangOpts());
383       X->getQualifier()->print(OS, PP);
384     }
385     Value += X->getDecl()->getNameAsString();
386     return Value;
387   }
388   if (auto *X = DTN.get<NamedDecl>())
389     Value += X->getNameAsString() + ";";
390   if (auto *X = DTN.get<TypedefNameDecl>())
391     return Value + X->getUnderlyingType().getAsString() + ";";
392   if (DTN.get<NamespaceDecl>())
393     return Value;
394   if (auto *X = DTN.get<TypeDecl>())
395     if (X->getTypeForDecl())
396       Value +=
397           X->getTypeForDecl()->getCanonicalTypeInternal().getAsString() + ";";
398   if (DTN.get<Decl>())
399     return Value;
400   if (DTN.get<Stmt>())
401     return "";
402   llvm_unreachable("Fatal: unhandled AST node.\n");
403 }
404 
/// Identifies a node in a subtree by its postorder offset, starting at 1.
/// Converts implicitly to int so it can index tables and take part in
/// arithmetic and comparisons.
struct SNodeId {
  int Id = 0;

  explicit SNodeId() = default;
  explicit SNodeId(int Id) : Id(Id) {}

  operator int() const { return Id; }

  SNodeId &operator++() {
    ++Id;
    return *this;
  }
  SNodeId &operator--() {
    --Id;
    return *this;
  }
  SNodeId operator+(int Other) const {
    SNodeId Result(*this);
    Result.Id += Other;
    return Result;
  }
};
417 
/// A view of the subtree rooted at a given node, indexed by SNodeId (1-based
/// postorder offset within the subtree), as required by ZhangShashaMatcher.
class Subtree {
private:
  /// The parent tree.
  const SyntaxTree::Impl &Tree;
  /// Maps SNodeIds to original ids.
  /// Entry I holds the original id of SNodeId(I + 1).
  std::vector<NodeId> RootIds;
  /// Maps subtree nodes to their leftmost descendants within the subtree.
  /// Values are 0-based postorder offsets, i.e. the SNodeId of the leftmost
  /// descendant minus one.
  std::vector<SNodeId> LeftMostDescendants;

public:
  /// For each distinct leftmost descendant, the highest node having it,
  /// in increasing SNodeId order (the last entry is the subtree root).
  std::vector<SNodeId> KeyRoots;

  Subtree(const SyntaxTree::Impl &Tree, NodeId SubtreeRoot) : Tree(Tree) {
    RootIds = getSubtreePostorder(Tree, SubtreeRoot);
    int NumLeaves = setLeftMostDescendants();
    computeKeyRoots(NumLeaves);
  }
  int getSize() const { return RootIds.size(); }
  NodeId getIdInRoot(SNodeId Id) const {
    assert(Id > 0 && Id <= getSize() && "Invalid subtree node index.");
    return RootIds[Id - 1];
  }
  const Node &getNode(SNodeId Id) const {
    return Tree.getNode(getIdInRoot(Id));
  }
  SNodeId getLeftMostDescendant(SNodeId Id) const {
    assert(Id > 0 && Id <= getSize() && "Invalid subtree node index.");
    return LeftMostDescendants[Id - 1];
  }
  /// Returns the postorder index of the leftmost descendant in the subtree.
  /// (A postorder index in the whole tree, despite the NodeId return type.)
  NodeId getPostorderOffset() const {
    return Tree.PostorderIds[getIdInRoot(SNodeId(1))];
  }
  std::string getNodeValue(SNodeId Id) const {
    return Tree.getNodeValue(getIdInRoot(Id));
  }

private:
  /// Returns the number of leaves in the subtree.
  /// Also fills LeftMostDescendants by translating each node's whole-tree
  /// leftmost descendant into a subtree-relative postorder offset.
  int setLeftMostDescendants() {
    int NumLeaves = 0;
    LeftMostDescendants.resize(getSize());
    for (int I = 0; I < getSize(); ++I) {
      SNodeId SI(I + 1);
      const Node &N = getNode(SI);
      NumLeaves += N.isLeaf();
      assert(I == Tree.PostorderIds[getIdInRoot(SI)] - getPostorderOffset() &&
             "Postorder traversal in subtree should correspond to traversal in "
             "the root tree by a constant offset.");
      LeftMostDescendants[I] = SNodeId(Tree.PostorderIds[N.LeftMostDescendant] -
                                       getPostorderOffset());
    }
    return NumLeaves;
  }
  // Walks nodes from the highest SNodeId down, keeping the first node seen
  // for each leftmost descendant; fills KeyRoots back to front so it ends up
  // ascending. Relies on #distinct leftmost descendants == #leaves.
  void computeKeyRoots(int Leaves) {
    KeyRoots.resize(Leaves);
    std::unordered_set<int> Visited;
    int K = Leaves - 1;
    for (SNodeId I(getSize()); I > 0; --I) {
      SNodeId LeftDesc = getLeftMostDescendant(I);
      if (Visited.count(LeftDesc))
        continue;
      assert(K >= 0 && "K should be non-negative");
      KeyRoots[K] = I;
      Visited.insert(LeftDesc);
      --K;
    }
  }
};
487 
488 /// Implementation of Zhang and Shasha's Algorithm for tree edit distance.
489 /// Computes an optimal mapping between two trees using only insertion,
490 /// deletion and update as edit actions (similar to the Levenshtein distance).
491 class ZhangShashaMatcher {
492   const ASTDiff::Impl &DiffImpl;
493   Subtree S1;
494   Subtree S2;
495   std::unique_ptr<std::unique_ptr<double[]>[]> TreeDist, ForestDist;
496 
497 public:
498   ZhangShashaMatcher(const ASTDiff::Impl &DiffImpl, const SyntaxTree::Impl &T1,
499                      const SyntaxTree::Impl &T2, NodeId Id1, NodeId Id2)
500       : DiffImpl(DiffImpl), S1(T1, Id1), S2(T2, Id2) {
501     TreeDist = llvm::make_unique<std::unique_ptr<double[]>[]>(
502         size_t(S1.getSize()) + 1);
503     ForestDist = llvm::make_unique<std::unique_ptr<double[]>[]>(
504         size_t(S1.getSize()) + 1);
505     for (int I = 0, E = S1.getSize() + 1; I < E; ++I) {
506       TreeDist[I] = llvm::make_unique<double[]>(size_t(S2.getSize()) + 1);
507       ForestDist[I] = llvm::make_unique<double[]>(size_t(S2.getSize()) + 1);
508     }
509   }
510 
511   std::vector<std::pair<NodeId, NodeId>> getMatchingNodes() {
512     std::vector<std::pair<NodeId, NodeId>> Matches;
513     std::vector<std::pair<SNodeId, SNodeId>> TreePairs;
514 
515     computeTreeDist();
516 
517     bool RootNodePair = true;
518 
519     TreePairs.emplace_back(SNodeId(S1.getSize()), SNodeId(S2.getSize()));
520 
521     while (!TreePairs.empty()) {
522       SNodeId LastRow, LastCol, FirstRow, FirstCol, Row, Col;
523       std::tie(LastRow, LastCol) = TreePairs.back();
524       TreePairs.pop_back();
525 
526       if (!RootNodePair) {
527         computeForestDist(LastRow, LastCol);
528       }
529 
530       RootNodePair = false;
531 
532       FirstRow = S1.getLeftMostDescendant(LastRow);
533       FirstCol = S2.getLeftMostDescendant(LastCol);
534 
535       Row = LastRow;
536       Col = LastCol;
537 
538       while (Row > FirstRow || Col > FirstCol) {
539         if (Row > FirstRow &&
540             ForestDist[Row - 1][Col] + 1 == ForestDist[Row][Col]) {
541           --Row;
542         } else if (Col > FirstCol &&
543                    ForestDist[Row][Col - 1] + 1 == ForestDist[Row][Col]) {
544           --Col;
545         } else {
546           SNodeId LMD1 = S1.getLeftMostDescendant(Row);
547           SNodeId LMD2 = S2.getLeftMostDescendant(Col);
548           if (LMD1 == S1.getLeftMostDescendant(LastRow) &&
549               LMD2 == S2.getLeftMostDescendant(LastCol)) {
550             NodeId Id1 = S1.getIdInRoot(Row);
551             NodeId Id2 = S2.getIdInRoot(Col);
552             assert(DiffImpl.isMatchingPossible(Id1, Id2) &&
553                    "These nodes must not be matched.");
554             Matches.emplace_back(Id1, Id2);
555             --Row;
556             --Col;
557           } else {
558             TreePairs.emplace_back(Row, Col);
559             Row = LMD1;
560             Col = LMD2;
561           }
562         }
563       }
564     }
565     return Matches;
566   }
567 
568 private:
569   /// We use a simple cost model for edit actions, which seems good enough.
570   /// Simple cost model for edit actions. This seems to make the matching
571   /// algorithm perform reasonably well.
572   /// The values range between 0 and 1, or infinity if this edit action should
573   /// always be avoided.
574   static constexpr double DeletionCost = 1;
575   static constexpr double InsertionCost = 1;
576 
577   double getUpdateCost(SNodeId Id1, SNodeId Id2) {
578     if (!DiffImpl.isMatchingPossible(S1.getIdInRoot(Id1), S2.getIdInRoot(Id2)))
579       return std::numeric_limits<double>::max();
580     return S1.getNodeValue(Id1) != S2.getNodeValue(Id2);
581   }
582 
583   void computeTreeDist() {
584     for (SNodeId Id1 : S1.KeyRoots)
585       for (SNodeId Id2 : S2.KeyRoots)
586         computeForestDist(Id1, Id2);
587   }
588 
589   void computeForestDist(SNodeId Id1, SNodeId Id2) {
590     assert(Id1 > 0 && Id2 > 0 && "Expecting offsets greater than 0.");
591     SNodeId LMD1 = S1.getLeftMostDescendant(Id1);
592     SNodeId LMD2 = S2.getLeftMostDescendant(Id2);
593 
594     ForestDist[LMD1][LMD2] = 0;
595     for (SNodeId D1 = LMD1 + 1; D1 <= Id1; ++D1) {
596       ForestDist[D1][LMD2] = ForestDist[D1 - 1][LMD2] + DeletionCost;
597       for (SNodeId D2 = LMD2 + 1; D2 <= Id2; ++D2) {
598         ForestDist[LMD1][D2] = ForestDist[LMD1][D2 - 1] + InsertionCost;
599         SNodeId DLMD1 = S1.getLeftMostDescendant(D1);
600         SNodeId DLMD2 = S2.getLeftMostDescendant(D2);
601         if (DLMD1 == LMD1 && DLMD2 == LMD2) {
602           double UpdateCost = getUpdateCost(D1, D2);
603           ForestDist[D1][D2] =
604               std::min({ForestDist[D1 - 1][D2] + DeletionCost,
605                         ForestDist[D1][D2 - 1] + InsertionCost,
606                         ForestDist[D1 - 1][D2 - 1] + UpdateCost});
607           TreeDist[D1][D2] = ForestDist[D1][D2];
608         } else {
609           ForestDist[D1][D2] =
610               std::min({ForestDist[D1 - 1][D2] + DeletionCost,
611                         ForestDist[D1][D2 - 1] + InsertionCost,
612                         ForestDist[DLMD1][DLMD2] + TreeDist[D1][D2]});
613         }
614       }
615     }
616   }
617 };
618 
619 ast_type_traits::ASTNodeKind Node::getType() const {
620   return ASTNode.getNodeKind();
621 }
622 
623 StringRef Node::getTypeLabel() const { return getType().asStringRef(); }
624 
625 namespace {
626 // Compares nodes by their depth.
627 struct HeightLess {
628   const SyntaxTree::Impl &Tree;
629   HeightLess(const SyntaxTree::Impl &Tree) : Tree(Tree) {}
630   bool operator()(NodeId Id1, NodeId Id2) const {
631     return Tree.getNode(Id1).Height < Tree.getNode(Id2).Height;
632   }
633 };
634 } // end anonymous namespace
635 
636 namespace {
637 // Priority queue for nodes, sorted descendingly by their height.
638 class PriorityList {
639   const SyntaxTree::Impl &Tree;
640   HeightLess Cmp;
641   std::vector<NodeId> Container;
642   PriorityQueue<NodeId, std::vector<NodeId>, HeightLess> List;
643 
644 public:
645   PriorityList(const SyntaxTree::Impl &Tree)
646       : Tree(Tree), Cmp(Tree), List(Cmp, Container) {}
647 
648   void push(NodeId id) { List.push(id); }
649 
650   std::vector<NodeId> pop() {
651     int Max = peekMax();
652     std::vector<NodeId> Result;
653     if (Max == 0)
654       return Result;
655     while (peekMax() == Max) {
656       Result.push_back(List.top());
657       List.pop();
658     }
659     // TODO this is here to get a stable output, not a good heuristic
660     std::sort(Result.begin(), Result.end());
661     return Result;
662   }
663   int peekMax() const {
664     if (List.empty())
665       return 0;
666     return Tree.getNode(List.top()).Height;
667   }
668   void open(NodeId Id) {
669     for (NodeId Child : Tree.getNode(Id).Children)
670       push(Child);
671   }
672 };
673 } // end anonymous namespace
674 
675 bool ASTDiff::Impl::identical(NodeId Id1, NodeId Id2) const {
676   const Node &N1 = T1.getNode(Id1);
677   const Node &N2 = T2.getNode(Id2);
678   if (N1.Children.size() != N2.Children.size() ||
679       !isMatchingPossible(Id1, Id2) ||
680       T1.getNodeValue(Id1) != T2.getNodeValue(Id2))
681     return false;
682   for (size_t Id = 0, E = N1.Children.size(); Id < E; ++Id)
683     if (!identical(N1.Children[Id], N2.Children[Id]))
684       return false;
685   return true;
686 }
687 
688 bool ASTDiff::Impl::isMatchingPossible(NodeId Id1, NodeId Id2) const {
689   return Options.isMatchingAllowed(T1.getNode(Id1), T2.getNode(Id2));
690 }
691 
692 bool ASTDiff::Impl::haveSameParents(const Mapping &M, NodeId Id1,
693                                     NodeId Id2) const {
694   NodeId P1 = T1.getNode(Id1).Parent;
695   NodeId P2 = T2.getNode(Id2).Parent;
696   return (P1.isInvalid() && P2.isInvalid()) ||
697          (P1.isValid() && P2.isValid() && M.getDst(P1) == P2);
698 }
699 
700 void ASTDiff::Impl::addOptimalMapping(Mapping &M, NodeId Id1,
701                                       NodeId Id2) const {
702   if (std::max(T1.getNumberOfDescendants(Id1),
703                T2.getNumberOfDescendants(Id2)) >= Options.MaxSize)
704     return;
705   ZhangShashaMatcher Matcher(*this, T1, T2, Id1, Id2);
706   std::vector<std::pair<NodeId, NodeId>> R = Matcher.getMatchingNodes();
707   for (const auto Tuple : R) {
708     NodeId Src = Tuple.first;
709     NodeId Dst = Tuple.second;
710     if (!M.hasSrc(Src) && !M.hasDst(Dst))
711       M.link(Src, Dst);
712   }
713 }
714 
715 double ASTDiff::Impl::getSimilarity(const Mapping &M, NodeId Id1,
716                                     NodeId Id2) const {
717   if (Id1.isInvalid() || Id2.isInvalid())
718     return 0.0;
719   int CommonDescendants = 0;
720   const Node &N1 = T1.getNode(Id1);
721   for (NodeId Id = Id1 + 1; Id <= N1.RightMostDescendant; ++Id)
722     CommonDescendants += int(T2.isInSubtree(M.getDst(Id), Id2));
723   return 2.0 * CommonDescendants /
724          (T1.getNumberOfDescendants(Id1) + T2.getNumberOfDescendants(Id2));
725 }
726 
727 NodeId ASTDiff::Impl::findCandidate(const Mapping &M, NodeId Id1) const {
728   NodeId Candidate;
729   double HighestSimilarity = 0.0;
730   for (NodeId Id2 : T2) {
731     if (!isMatchingPossible(Id1, Id2))
732       continue;
733     if (M.hasDst(Id2))
734       continue;
735     double Similarity = getSimilarity(M, Id1, Id2);
736     if (Similarity >= Options.MinSimilarity && Similarity > HighestSimilarity) {
737       HighestSimilarity = Similarity;
738       Candidate = Id2;
739     }
740   }
741   return Candidate;
742 }
743 
744 void ASTDiff::Impl::matchBottomUp(Mapping &M) const {
745   std::vector<NodeId> Postorder = getSubtreePostorder(T1, T1.getRootId());
746   for (NodeId Id1 : Postorder) {
747     if (Id1 == T1.getRootId() && !M.hasSrc(T1.getRootId()) &&
748         !M.hasDst(T2.getRootId())) {
749       if (isMatchingPossible(T1.getRootId(), T2.getRootId())) {
750         M.link(T1.getRootId(), T2.getRootId());
751         addOptimalMapping(M, T1.getRootId(), T2.getRootId());
752       }
753       break;
754     }
755     const Node &N1 = T1.getNode(Id1);
756     bool Matched = M.hasSrc(Id1);
757     bool MatchedChildren =
758         std::any_of(N1.Children.begin(), N1.Children.end(),
759                     [&](NodeId Child) { return M.hasSrc(Child); });
760     if (Matched || !MatchedChildren)
761       continue;
762     NodeId Id2 = findCandidate(M, Id1);
763     if (Id2.isValid()) {
764       M.link(Id1, Id2);
765       addOptimalMapping(M, Id1, Id2);
766     }
767   }
768 }
769 
770 Mapping ASTDiff::Impl::matchTopDown() const {
771   PriorityList L1(T1);
772   PriorityList L2(T2);
773 
774   Mapping M(T1.getSize() + T2.getSize());
775 
776   L1.push(T1.getRootId());
777   L2.push(T2.getRootId());
778 
779   int Max1, Max2;
780   while (std::min(Max1 = L1.peekMax(), Max2 = L2.peekMax()) >
781          Options.MinHeight) {
782     if (Max1 > Max2) {
783       for (NodeId Id : L1.pop())
784         L1.open(Id);
785       continue;
786     }
787     if (Max2 > Max1) {
788       for (NodeId Id : L2.pop())
789         L2.open(Id);
790       continue;
791     }
792     std::vector<NodeId> H1, H2;
793     H1 = L1.pop();
794     H2 = L2.pop();
795     for (NodeId Id1 : H1) {
796       for (NodeId Id2 : H2) {
797         if (identical(Id1, Id2) && !M.hasSrc(Id1) && !M.hasDst(Id2)) {
798           for (int I = 0, E = T1.getNumberOfDescendants(Id1); I < E; ++I)
799             M.link(Id1 + I, Id2 + I);
800         }
801       }
802     }
803     for (NodeId Id1 : H1) {
804       if (!M.hasSrc(Id1))
805         L1.open(Id1);
806     }
807     for (NodeId Id2 : H2) {
808       if (!M.hasDst(Id2))
809         L2.open(Id2);
810     }
811   }
812   return M;
813 }
814 
815 ASTDiff::Impl::Impl(SyntaxTree::Impl &T1, SyntaxTree::Impl &T2,
816                     const ComparisonOptions &Options)
817     : T1(T1), T2(T2), Options(Options) {
818   computeMapping();
819   computeChangeKinds(TheMapping);
820 }
821 
822 void ASTDiff::Impl::computeMapping() {
823   TheMapping = matchTopDown();
824   matchBottomUp(TheMapping);
825 }
826 
827 void ASTDiff::Impl::computeChangeKinds(Mapping &M) {
828   for (NodeId Id1 : T1) {
829     if (!M.hasSrc(Id1)) {
830       T1.getMutableNode(Id1).Change = Delete;
831       T1.getMutableNode(Id1).Shift -= 1;
832     }
833   }
834   for (NodeId Id2 : T2) {
835     if (!M.hasDst(Id2)) {
836       T2.getMutableNode(Id2).Change = Insert;
837       T2.getMutableNode(Id2).Shift -= 1;
838     }
839   }
840   for (NodeId Id1 : T1.NodesBfs) {
841     NodeId Id2 = M.getDst(Id1);
842     if (Id2.isInvalid())
843       continue;
844     if (!haveSameParents(M, Id1, Id2) ||
845         T1.findPositionInParent(Id1, true) !=
846             T2.findPositionInParent(Id2, true)) {
847       T1.getMutableNode(Id1).Shift -= 1;
848       T2.getMutableNode(Id2).Shift -= 1;
849     }
850   }
851   for (NodeId Id2 : T2.NodesBfs) {
852     NodeId Id1 = M.getSrc(Id2);
853     if (Id1.isInvalid())
854       continue;
855     Node &N1 = T1.getMutableNode(Id1);
856     Node &N2 = T2.getMutableNode(Id2);
857     if (Id1.isInvalid())
858       continue;
859     if (!haveSameParents(M, Id1, Id2) ||
860         T1.findPositionInParent(Id1, true) !=
861             T2.findPositionInParent(Id2, true)) {
862       N1.Change = N2.Change = Move;
863     }
864     if (T1.getNodeValue(Id1) != T2.getNodeValue(Id2)) {
865       N1.Change = N2.Change = (N1.Change == Move ? UpdateMove : Update);
866     }
867   }
868 }
869 
870 ASTDiff::ASTDiff(SyntaxTree &T1, SyntaxTree &T2,
871                  const ComparisonOptions &Options)
872     : DiffImpl(llvm::make_unique<Impl>(*T1.TreeImpl, *T2.TreeImpl, Options)) {}
873 
874 ASTDiff::~ASTDiff() = default;
875 
876 NodeId ASTDiff::getMapped(const SyntaxTree &SourceTree, NodeId Id) const {
877   return DiffImpl->getMapped(SourceTree.TreeImpl, Id);
878 }
879 
880 SyntaxTree::SyntaxTree(const ASTContext &AST)
881     : TreeImpl(llvm::make_unique<SyntaxTree::Impl>(
882           this, AST.getTranslationUnitDecl(), AST)) {}
883 
884 SyntaxTree::~SyntaxTree() = default;
885 
886 const ASTContext &SyntaxTree::getASTContext() const { return TreeImpl->AST; }
887 
888 const Node &SyntaxTree::getNode(NodeId Id) const {
889   return TreeImpl->getNode(Id);
890 }
891 
892 int SyntaxTree::getSize() const { return TreeImpl->getSize(); }
893 NodeId SyntaxTree::getRootId() const { return TreeImpl->getRootId(); }
894 SyntaxTree::PreorderIterator SyntaxTree::begin() const {
895   return TreeImpl->begin();
896 }
897 SyntaxTree::PreorderIterator SyntaxTree::end() const { return TreeImpl->end(); }
898 
899 int SyntaxTree::findPositionInParent(NodeId Id) const {
900   return TreeImpl->findPositionInParent(Id);
901 }
902 
903 std::pair<unsigned, unsigned>
904 SyntaxTree::getSourceRangeOffsets(const Node &N) const {
905   const SourceManager &SrcMgr = TreeImpl->AST.getSourceManager();
906   SourceRange Range = N.ASTNode.getSourceRange();
907   SourceLocation BeginLoc = Range.getBegin();
908   SourceLocation EndLoc = Lexer::getLocForEndOfToken(
909       Range.getEnd(), /*Offset=*/0, SrcMgr, TreeImpl->AST.getLangOpts());
910   if (auto *ThisExpr = N.ASTNode.get<CXXThisExpr>()) {
911     if (ThisExpr->isImplicit())
912       EndLoc = BeginLoc;
913   }
914   unsigned Begin = SrcMgr.getFileOffset(SrcMgr.getExpansionLoc(BeginLoc));
915   unsigned End = SrcMgr.getFileOffset(SrcMgr.getExpansionLoc(EndLoc));
916   return {Begin, End};
917 }
918 
919 std::string SyntaxTree::getNodeValue(NodeId Id) const {
920   return TreeImpl->getNodeValue(Id);
921 }
922 
923 std::string SyntaxTree::getNodeValue(const Node &N) const {
924   return TreeImpl->getNodeValue(N);
925 }
926 
927 } // end namespace diff
928 } // end namespace clang
929