Teach the built-in shell test runner in lit to handle '|&'-style pipes.
[oota-llvm.git] / utils / lit / lit / ShUtil.py
1 import itertools
2
3 import Util
4 from ShCommands import Command, Pipeline, Seq
5
class ShLexer:
    """
    ShLexer - Tokenizer for the subset of 'sh' syntax handled by this file.

    Arguments are produced as plain strings. Operators are produced as
    1-tuples such as ('|',) or ('>>',). A redirection operator preceded by a
    file descriptor number (e.g. '2>') is produced as the 2-tuple
    (operator, fd).

    data         - The command line text to tokenize.
    win32Escapes - When true, a backslash outside of quotes is kept as an
                   ordinary character instead of escaping the next one.
    """

    def __init__(self, data, win32Escapes = False):
        self.data = data
        self.pos = 0
        self.end = len(data)
        self.win32Escapes = win32Escapes

    def eat(self):
        """
        eat() - Consume and return the next character. The caller must ensure
        the input is not exhausted. """
        c = self.data[self.pos]
        self.pos += 1
        return c

    def look(self):
        """
        look() - Return the next character without consuming it. The caller
        must ensure the input is not exhausted. """
        return self.data[self.pos]

    def maybe_eat(self, c):
        """
        maybe_eat(c) - Consume the character c if it is the next character,
        returning True if a character was consumed. Returns False at the end
        of the input. """
        # The bounds check matters for input that ends in an operator
        # (e.g. 'a |'); without it the lookahead indexes past the data.
        if self.pos < self.end and self.data[self.pos] == c:
            self.pos += 1
            return True
        return False

    def lex_arg_fast(self, c):
        """
        lex_arg_fast(c) - Try to lex an argument which contains no special
        characters. 'c' is the first character of the argument and has
        already been consumed. Returns the argument, or None (without moving
        the position) if the slow path is required. """
        # Get the leading whitespace free section.
        chunk = self.data[self.pos - 1:].split(None, 1)[0]

        # If it has special characters, the fast path failed.
        if ('|' in chunk or '&' in chunk or
            '<' in chunk or '>' in chunk or
            "'" in chunk or '"' in chunk or
            '\\' in chunk):
            return None

        # 'c' was already consumed, so the chunk started at pos - 1.
        self.pos = self.pos - 1 + len(chunk)
        return chunk

    def lex_arg_slow(self, c):
        """
        lex_arg_slow(c) - Lex an argument character by character, handling
        quotes and escapes. 'c' is the first character of the argument and
        has already been consumed. Returns the argument string, or an
        (operator, fd) tuple when the text is a number immediately followed
        by a redirection operator. """
        if c in "'\"":
            arg = self.lex_arg_quoted(c)
        else:
            arg = c
        while self.pos != self.end:
            c = self.look()
            if c.isspace() or c in "|&":
                break
            elif c in '><':
                # This is an annoying case; we treat '2>' as a single token so
                # we don't have to track whitespace tokens.

                # If the parse string isn't an integer, do the usual thing.
                if not arg.isdigit():
                    break

                # Otherwise, lex the operator and convert to a redirection
                # token.
                num = int(arg)
                tok = self.lex_one_token()
                assert isinstance(tok, tuple) and len(tok) == 1
                return (tok[0], num)
            elif c == '"':
                self.eat()
                arg += self.lex_arg_quoted('"')
            elif c == "'":
                self.eat()
                arg += self.lex_arg_quoted("'")
            elif not self.win32Escapes and c == '\\':
                # Outside of a string, '\\' escapes everything.
                self.eat()
                if self.pos == self.end:
                    Util.warning("escape at end of quoted argument in: %r" %
                                 self.data)
                    return arg
                arg += self.eat()
            else:
                arg += self.eat()
        return arg

    def lex_arg_quoted(self, delim):
        """
        lex_arg_quoted(delim) - Lex the body of a quoted string whose opening
        'delim' has already been consumed. Consumes the closing delimiter and
        returns the contents. Warns and returns what was read so far if the
        input ends first. """
        text = ''
        while self.pos != self.end:
            c = self.eat()
            if c == delim:
                return text
            elif c == '\\' and delim == '"':
                # Inside a '"' quoted string, '\\' only escapes the quote
                # character and backslash, otherwise it is preserved.
                if self.pos == self.end:
                    Util.warning("escape at end of quoted argument in: %r" %
                                 self.data)
                    return text
                c = self.eat()
                if c == '"':
                    text += '"'
                elif c == '\\':
                    text += '\\'
                else:
                    text += '\\' + c
            else:
                text += c
        Util.warning("missing quote character in %r" % self.data)
        return text

    def lex_arg_checked(self, c):
        """
        lex_arg_checked(c) - Lex an argument with both the fast and the slow
        path and raise ValueError if the fast path succeeded but disagrees
        with the slow path on the result or on the final position. Returns
        the slow path result. """
        pos = self.pos
        res = self.lex_arg_fast(c)
        end = self.pos

        # Rewind so the slow path sees the same input.
        self.pos = pos
        reference = self.lex_arg_slow(c)
        if res is not None:
            if res != reference:
                raise ValueError("Fast path failure: %r != %r" %
                                 (res, reference))
            if self.pos != end:
                raise ValueError("Fast path failure: %r != %r" %
                                 (self.pos, end))
        return reference

    def lex_arg(self, c):
        """
        lex_arg(c) - Lex an argument starting with the already consumed
        character 'c', falling back to the slow path when the fast path
        declines. """
        return self.lex_arg_fast(c) or self.lex_arg_slow(c)

    def lex_one_token(self):
        """
        lex_one_token - Lex a single 'sh' token. """

        c = self.eat()
        if c in ';!':
            return (c,)
        if c == '|':
            if self.maybe_eat('|'):
                return ('||',)
            if self.maybe_eat('&'):
                return ('|&',)
            return (c,)
        if c == '&':
            if self.maybe_eat('&'):
                return ('&&',)
            if self.maybe_eat('>'):
                return ('&>',)
            return (c,)
        if c == '>':
            if self.maybe_eat('&'):
                return ('>&',)
            if self.maybe_eat('>'):
                return ('>>',)
            return (c,)
        if c == '<':
            if self.maybe_eat('&'):
                return ('<&',)
            # '<<' must be matched by a second '<'; '<>' is a distinct
            # operator and is reported under its own spelling.
            if self.maybe_eat('<'):
                return ('<<',)
            if self.maybe_eat('>'):
                return ('<>',)
            return (c,)

        return self.lex_arg(c)

    def lex(self):
        """
        lex() - Generate the tokens of the input, skipping whitespace between
        tokens. """
        while self.pos != self.end:
            if self.look().isspace():
                self.eat()
            else:
                yield self.lex_one_token()
167
168 ###
169  
class ShParser:
    """
    ShParser - Parser which turns a shell command line into Command,
    Pipeline and Seq objects, using ShLexer to produce the tokens.

    data         - The command line text to parse.
    win32Escapes - Forwarded to ShLexer.
    """

    def __init__(self, data, win32Escapes = False):
        self.data = data
        self.tokens = ShLexer(data, win32Escapes = win32Escapes).lex()

    def lex(self):
        """
        lex() - Consume and return the next token, or None when there are no
        tokens left. """
        # The next() builtin works for generators and itertools.chain objects
        # alike, on both Python 2 and Python 3.
        return next(self.tokens, None)

    def look(self):
        """
        look() - Return the next token without consuming it, or None when
        there are no tokens left. """
        tok = self.lex()
        if tok is not None:
            # Push the token back in front of the remaining ones.
            self.tokens = itertools.chain([tok], self.tokens)
        return tok

    def parse_command(self):
        """
        parse_command() - Parse one command: its arguments and redirections,
        stopping before a terminating operator or at the end of the input.
        Raises ValueError if the command is empty, starts with an operator,
        or has a redirection with no target. """
        tok = self.lex()
        if not tok:
            raise ValueError("empty command!")
        if isinstance(tok, tuple):
            raise ValueError("syntax error near unexpected token %r" % tok[0])

        args = [tok]
        redirects = []
        while 1:
            tok = self.look()

            # EOF?
            if tok is None:
                break

            # If this is an argument, just add it to the current command.
            if isinstance(tok, str):
                args.append(self.lex())
                continue

            # Otherwise see if it is a terminator.
            assert isinstance(tok, tuple)
            if tok[0] in ('|','|&',';','&','||','&&'):
                break

            # Otherwise it must be a redirection.
            op = self.lex()
            arg = self.lex()
            if not arg:
                raise ValueError("syntax error near token %r" % op[0])
            redirects.append((op, arg))

        return Command(args, redirects)

    def parse_pipeline(self):
        """
        parse_pipeline() - Parse an optionally negated ('!') sequence of
        commands joined by '|' or '|&' and return it as a Pipeline. """
        negate = False
        if self.look() == ('!',):
            self.lex()
            negate = True

        commands = [self.parse_command()]
        while 1:
            tok = self.look()
            if tok == ('|',):
                self.lex()
                commands.append(self.parse_command())
                continue
            if tok == ('|&',):
                self.lex()
                # 'a |& b' is handled as 'a 2>&1 | b' by adding the
                # redirection to the command on the left of the pipe.
                # NOTE(review): the redirection is placed before the
                # command's own redirections; bash documents the implicit
                # 2>&1 as applied after them - confirm this order is intended.
                commands[-1].redirects.insert(0, (('>&',2),'1'))
                commands.append(self.parse_command())
                continue
            break
        return Pipeline(commands, negate)

    def parse(self):
        """
        parse() - Parse the whole input. Returns a single Pipeline, or nested
        Seq objects when pipelines are joined by operators; operators are
        grouped left to right. Raises ValueError when an operator is not
        followed by anything. """
        lhs = self.parse_pipeline()

        while self.look():
            operator = self.lex()
            assert isinstance(operator, tuple) and len(operator) == 1

            if not self.look():
                raise ValueError("missing argument to operator %r" %
                                 operator[0])

            # FIXME: Operator precedence!!
            lhs = Seq(lhs, operator[0], self.parse_pipeline())

        return lhs
257
258 ###
259
260 import unittest
261
class TestShLexer(unittest.TestCase):
    """Unit tests for the tokens produced by ShLexer."""

    def lex(self, str, *args, **kwargs):
        # Collect the generated tokens so they can be compared with a list.
        return list(ShLexer(str, *args, **kwargs).lex())

    def test_basic(self):
        self.assertEqual(self.lex('a|b>c&d<e'),
                         ['a', ('|',), 'b', ('>',), 'c', ('&',), 'd',
                          ('<',), 'e'])

    def test_operators(self):
        # Two character operators must come out as a single token.
        self.assertEqual(self.lex('a |& b'),
                         ['a', ('|&',), 'b'])
        self.assertEqual(self.lex('a || b'),
                         ['a', ('||',), 'b'])
        self.assertEqual(self.lex('a && b'),
                         ['a', ('&&',), 'b'])
        self.assertEqual(self.lex('a >> b'),
                         ['a', ('>>',), 'b'])
        self.assertEqual(self.lex('a &> b'),
                         ['a', ('&>',), 'b'])
        self.assertEqual(self.lex('a <& b'),
                         ['a', ('<&',), 'b'])

    def test_redirection_tokens(self):
        # A number is only part of the redirection when it is a whole word.
        self.assertEqual(self.lex('a2>c'),
                         ['a2', ('>',), 'c'])
        self.assertEqual(self.lex('a 2>c'),
                         ['a', ('>',2), 'c'])
        self.assertEqual(self.lex('a 2>&1'),
                         ['a', ('>&',2), '1'])

    def test_quoting(self):
        self.assertEqual(self.lex(""" 'a' """),
                         ['a'])
        self.assertEqual(self.lex(""" "hello\\"world" """),
                         ['hello"world'])
        self.assertEqual(self.lex(""" "hello\\'world" """),
                         ["hello\\'world"])
        self.assertEqual(self.lex(""" "hello\\\\world" """),
                         ["hello\\world"])
        self.assertEqual(self.lex(""" he"llo wo"rld """),
                         ["hello world"])
        self.assertEqual(self.lex(""" a\\ b a\\\\b """),
                         ["a b", "a\\b"])
        self.assertEqual(self.lex(""" "" "" """),
                         ["", ""])
        self.assertEqual(self.lex(""" a\\ b """, win32Escapes = True),
                         ['a\\', 'b'])
294
class TestShParse(unittest.TestCase):
    """Unit tests for the command structures produced by ShParser."""

    def parse(self, str):
        return ShParser(str).parse()

    def test_basic(self):
        self.assertEqual(self.parse('echo hello'),
                         Pipeline([Command(['echo', 'hello'], [])], False))
        self.assertEqual(self.parse('echo ""'),
                         Pipeline([Command(['echo', ''], [])], False))
        self.assertEqual(self.parse("""echo -DFOO='a'"""),
                         Pipeline([Command(['echo', '-DFOO=a'], [])], False))
        self.assertEqual(self.parse('echo -DFOO="a"'),
                         Pipeline([Command(['echo', '-DFOO=a'], [])], False))

    def test_redirection(self):
        self.assertEqual(self.parse('echo hello > c'),
                         Pipeline([Command(['echo', 'hello'],
                                           [(('>',), 'c')])], False))
        self.assertEqual(self.parse('echo hello > c >> d'),
                         Pipeline([Command(['echo', 'hello'], [(('>',), 'c'),
                                                     (('>>',), 'd')])], False))
        self.assertEqual(self.parse('a 2>&1'),
                         Pipeline([Command(['a'], [(('>&',2), '1')])], False))

    def test_pipeline(self):
        self.assertEqual(self.parse('a | b'),
                         Pipeline([Command(['a'], []),
                                   Command(['b'], [])],
                                  False))

        self.assertEqual(self.parse('a | b | c'),
                         Pipeline([Command(['a'], []),
                                   Command(['b'], []),
                                   Command(['c'], [])],
                                  False))

        self.assertEqual(self.parse('! a'),
                         Pipeline([Command(['a'], [])],
                                  True))

        # '|&' adds a '2>&1' redirection to the command on its left.
        self.assertEqual(self.parse('a |& b'),
                         Pipeline([Command(['a'], [(('>&',2), '1')]),
                                   Command(['b'], [])],
                                  False))

        self.assertEqual(self.parse('a |& b | c'),
                         Pipeline([Command(['a'], [(('>&',2), '1')]),
                                   Command(['b'], []),
                                   Command(['c'], [])],
                                  False))

    def test_list(self):
        self.assertEqual(self.parse('a ; b'),
                         Seq(Pipeline([Command(['a'], [])], False),
                             ';',
                             Pipeline([Command(['b'], [])], False)))

        self.assertEqual(self.parse('a & b'),
                         Seq(Pipeline([Command(['a'], [])], False),
                             '&',
                             Pipeline([Command(['b'], [])], False)))

        self.assertEqual(self.parse('a && b'),
                         Seq(Pipeline([Command(['a'], [])], False),
                             '&&',
                             Pipeline([Command(['b'], [])], False)))

        self.assertEqual(self.parse('a || b'),
                         Seq(Pipeline([Command(['a'], [])], False),
                             '||',
                             Pipeline([Command(['b'], [])], False)))

        # Operators are grouped left to right.
        self.assertEqual(self.parse('a && b || c'),
                         Seq(Seq(Pipeline([Command(['a'], [])], False),
                                 '&&',
                                 Pipeline([Command(['b'], [])], False)),
                             '||',
                             Pipeline([Command(['c'], [])], False)))
362
# Run the unit tests defined above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()