Cosmetic change.
[oota-llvm.git] / include / llvm / ADT / Trie.h
1 //===- llvm/ADT/Trie.h ---- Generic trie structure --------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Anton Korobeynikov and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This class defines a generic trie structure. The trie structure
11 // is immutable after creation, but the payload contained within it is not.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef LLVM_ADT_TRIE_H
16 #define LLVM_ADT_TRIE_H
17
18 #include <map>
19 #include <vector>
20
21 namespace llvm {
22
23 // FIXME:
24 // - Labels are usually small, maybe it's better to use SmallString
25 // - Should we use char* during construction?
26 // - Should we templatize Empty with traits-like interface?
27 // - GraphTraits interface
28
29 template<class Payload>
30 class Trie {
31   class Node {
32     friend class Trie;
33
34     typedef enum {
35       Same           = -3,
36       StringIsPrefix = -2,
37       LabelIsPrefix  = -1,
38       DontMatch      = 0,
39       HaveCommonPart
40     } QueryResult;
41     typedef std::vector<Node*> NodeVector;
42     typedef typename std::vector<Node*>::iterator NodeVectorIter;
43
44     struct NodeCmp {
45       bool operator() (Node* N1, Node* N2) {
46         return (N1->Label[0] < N2->Label[0]);
47       }
48       bool operator() (Node* N, char Id) {
49         return (N->Label[0] < Id);
50       }
51     };
52
53     std::string Label;
54     Payload Data;
55     NodeVector Children;
56   public:
57     inline explicit Node(const Payload& data, const std::string& label = ""):
58         Label(label), Data(data) { }
59
60     inline Node(const Node& n) {
61       Data = n.Data;
62       Children = n.Children;
63       Label = n.Label;
64     }
65     inline Node& operator=(const Node& n) {
66       if (&n != this) {
67         Data = n.Data;
68         Children = n.Children;
69         Label = n.Label;
70       }
71
72       return *this;
73     }
74
75     inline bool isLeaf() const { return Children.empty(); }
76
77     inline const Payload& getData() const { return Data; }
78     inline void setData(const Payload& data) { Data = data; }
79
80     inline void setLabel(const std::string& label) { Label = label; }
81     inline const std::string& getLabel() const { return Label; }
82
83 #if 0
84     inline void dump() {
85       std::cerr << "Node: " << this << "\n"
86                 << "Label: " << Label << "\n"
87                 << "Children:\n";
88
89       for (NodeVectorIter I = Children.begin(), E = Children.end(); I != E; ++I)
90         std::cerr << (*I)->Label << "\n";
91     }
92 #endif
93
94     inline void addEdge(Node* N) {
95       if (Children.empty())
96         Children.push_back(N);
97       else {
98         NodeVectorIter I = std::lower_bound(Children.begin(), Children.end(),
99                                             N, NodeCmp());
100         // FIXME: no dups are allowed
101         Children.insert(I, N);
102       }
103     }
104
105     inline Node* getEdge(char Id) {
106       Node* fNode = NULL;
107       NodeVectorIter I = std::lower_bound(Children.begin(), Children.end(),
108                                           Id, NodeCmp());
109       if (I != Children.end() && (*I)->Label[0] == Id)
110         fNode = *I;
111
112       return fNode;
113     }
114
115     inline void setEdge(Node* N) {
116       char Id = N->Label[0];
117       NodeVectorIter I = std::lower_bound(Children.begin(), Children.end(),
118                                           Id, NodeCmp());
119       assert(I != Children.end() && "Node does not exists!");
120       *I = N;
121     }
122
123     QueryResult query(const std::string& s) const {
124       unsigned i, l;
125       unsigned l1 = s.length();
126       unsigned l2 = Label.length();
127
128       // Find the length of common part
129       l = std::min(l1, l2);
130       i = 0;
131       while ((i < l) && (s[i] == Label[i]))
132         ++i;
133
134       if (i == l) { // One is prefix of another, find who is who
135         if (l1 == l2)
136           return Same;
137         else if (i == l1)
138           return StringIsPrefix;
139         else
140           return LabelIsPrefix;
141       } else // s and Label have common (possible empty) part, return its length
142         return (QueryResult)i;
143     }
144   };
145
146   std::vector<Node*> Nodes;
147   Payload Empty;
148
149   inline Node* getRoot() const { return Nodes[0]; }
150
151   inline Node* addNode(const Payload& data, const std::string label = "") {
152     Node* N = new Node(data, label);
153     Nodes.push_back(N);
154     return N;
155   }
156
157   inline Node* splitEdge(Node* N, char Id, size_t index) {
158     Node* eNode = N->getEdge(Id);
159     assert(eNode && "Node doesn't exist");
160
161     const std::string &l = eNode->Label;
162     assert(index > 0 && index < l.length() && "Trying to split too far!");
163     std::string l1 = l.substr(0, index);
164     std::string l2 = l.substr(index);
165
166     Node* nNode = addNode(Empty, l1);
167     N->setEdge(nNode);
168
169     eNode->Label = l2;
170     nNode->addEdge(eNode);
171
172     return nNode;
173   }
174
175 public:
176   inline explicit Trie(const Payload& empty):Empty(empty) {
177     addNode(Empty);
178   }
179   inline ~Trie() {
180     for (unsigned i = 0, e = Nodes.size(); i != e; ++i)
181       delete Nodes[i];
182   }
183
184   bool addString(const std::string& s, const Payload& data) {
185     Node* cNode = getRoot();
186     Node* tNode = NULL;
187     std::string s1(s);
188
189     while (tNode == NULL) {
190       char Id = s1[0];
191       if (Node* nNode = cNode->getEdge(Id)) {
192         typename Node::QueryResult r = nNode->query(s1);
193
194         switch (r) {
195         case Node::Same:
196         case Node::StringIsPrefix:
197           // Currently we don't allow to have two strings in the trie one
198           // being a prefix of another. This should be fixed.
199           assert(0 && "FIXME!");
200           return false;
201         case Node::DontMatch:
202           assert(0 && "Impossible!");
203           return false;
204         case Node::LabelIsPrefix:
205           s1 = s1.substr(nNode->getLabel().length());
206           cNode = nNode;
207           break;
208         default:
209          nNode = splitEdge(cNode, Id, r);
210          tNode = addNode(data, s1.substr(r));
211          nNode->addEdge(tNode);
212        }
213       } else {
214         tNode = addNode(data, s1);
215         cNode->addEdge(tNode);
216       }
217     }
218
219     return true;
220   }
221
222   const Payload& lookup(const std::string& s) const {
223     Node* cNode = getRoot();
224     Node* tNode = NULL;
225     std::string s1(s);
226
227     while (tNode == NULL) {
228       char Id = s1[0];
229       if (Node* nNode = cNode->getEdge(Id)) {
230         typename Node::QueryResult r = nNode->query(s1);
231
232         switch (r) {
233         case Node::Same:
234           tNode = nNode;
235           break;
236         case Node::StringIsPrefix:
237           return Empty;
238         case Node::DontMatch:
239           assert(0 && "Impossible!");
240           return Empty;
241         case Node::LabelIsPrefix:
242           s1 = s1.substr(nNode->getLabel().length());
243           cNode = nNode;
244           break;
245         default:
246           return Empty;
247         }
248       } else
249         return Empty;
250     }
251
252     return tNode->getData();
253   }
254
255 };
256
257 }
258
259 #endif // LLVM_ADT_TRIE_H