Bug fixes and switches
[IRC.git] / Robust / src / IR / Flat / BuildCode.java
1 package IR.Flat;
2 import IR.*;
3 import java.util.*;
4 import java.io.*;
5
6 public class BuildCode {
7     State state;
8     Hashtable temptovar;
9     Hashtable paramstable;
10     Hashtable tempstable;
11     int tag=0;
12     String localsprefix="___locals___";
13     String paramsprefix="___params___";
14     public static boolean GENERATEPRECISEGC=false;
15     public static String PREFIX="";
16
17     public BuildCode(State st, Hashtable temptovar) {
18         state=st;
19         this.temptovar=temptovar;
20         paramstable=new Hashtable();    
21         tempstable=new Hashtable();
22     }
23
24     public void buildCode() {
25         Iterator it=state.getClassSymbolTable().getDescriptorsIterator();
26         PrintWriter outclassdefs=null;
27         PrintWriter outstructs=null;
28         PrintWriter outmethodheader=null;
29         PrintWriter outmethod=null;
30         try {
31             OutputStream str=new FileOutputStream(PREFIX+"structdefs.h");
32             outstructs=new java.io.PrintWriter(str, true);
33             str=new FileOutputStream(PREFIX+"methodheaders.h");
34             outmethodheader=new java.io.PrintWriter(str, true);
35             str=new FileOutputStream(PREFIX+"classdefs.h");
36             outclassdefs=new java.io.PrintWriter(str, true);
37             str=new FileOutputStream(PREFIX+"methods.c");
38             outmethod=new java.io.PrintWriter(str, true);
39         } catch (Exception e) {
40             e.printStackTrace();
41             System.exit(-1);
42         }
43         outstructs.println("#include \"classdefs.h\"");
44         outmethodheader.println("#include \"structdefs.h\"");
45
46         // Output the C declarations
47         // These could mutually reference each other
48         while(it.hasNext()) {
49             ClassDescriptor cn=(ClassDescriptor)it.next();
50             outclassdefs.println("struct "+cn.getSafeSymbol()+";");
51         }
52         outclassdefs.println("");
53
54         it=state.getClassSymbolTable().getDescriptorsIterator();
55         while(it.hasNext()) {
56             ClassDescriptor cn=(ClassDescriptor)it.next();
57             generateCallStructs(cn, outclassdefs, outstructs, outmethodheader);
58         }
59         outstructs.close();
60         outmethodheader.close();
61
62         /* Build the actual methods */
63         outmethod.println("#include \"methodheaders.h\"");
64         outmethod.println("#include <runtime.h>");
65         Iterator classit=state.getClassSymbolTable().getDescriptorsIterator();
66         while(classit.hasNext()) {
67             ClassDescriptor cn=(ClassDescriptor)classit.next();
68             Iterator methodit=cn.getMethods();
69             while(methodit.hasNext()) {
70                 /* Classify parameters */
71                 MethodDescriptor md=(MethodDescriptor)methodit.next();
72                 FlatMethod fm=state.getMethodFlat(md);
73                 generateFlatMethod(fm,outmethod);
74             }
75         }
76         outmethod.close();
77     }
78
79     private void generateTempStructs(FlatMethod fm) {
80         MethodDescriptor md=fm.getMethod();
81         ParamsObject objectparams=new ParamsObject(md,tag++);
82         paramstable.put(md, objectparams);
83         for(int i=0;i<fm.numParameters();i++) {
84             TempDescriptor temp=fm.getParameter(i);
85             TypeDescriptor type=temp.getType();
86             if (type.isPtr()&&GENERATEPRECISEGC)
87                 objectparams.addPtr(temp);
88             else
89                 objectparams.addPrim(temp);
90         }
91
92         TempObject objecttemps=new TempObject(objectparams,md,tag++);
93         tempstable.put(md, objecttemps);
94         for(Iterator nodeit=fm.getNodeSet().iterator();nodeit.hasNext();) {
95             FlatNode fn=(FlatNode)nodeit.next();
96             TempDescriptor[] writes=fn.writesTemps();
97             for(int i=0;i<writes.length;i++) {
98                 TempDescriptor temp=writes[i];
99                 TypeDescriptor type=temp.getType();
100                 if (type.isPtr()&&GENERATEPRECISEGC)
101                     objecttemps.addPtr(temp);
102                 else
103                     objecttemps.addPrim(temp);
104             }
105         }
106     }
107
108     private void generateCallStructs(ClassDescriptor cn, PrintWriter classdefout, PrintWriter output, PrintWriter headersout) {
109         /* Output class structure */
110         Iterator fieldit=cn.getFields();
111         classdefout.println("struct "+cn.getSafeSymbol()+" {");
112         classdefout.println("  int type;");
113         while(fieldit.hasNext()) {
114             FieldDescriptor fd=(FieldDescriptor)fieldit.next();
115             if (fd.getType().isClass())
116                 classdefout.println("  struct "+fd.getType().getSafeSymbol()+" * "+fd.getSafeSymbol()+";");
117             else 
118                 classdefout.println("  "+fd.getType().getSafeSymbol()+" "+fd.getSafeSymbol()+";");
119         }
120         classdefout.println("};\n");
121
122         /* Cycle through methods */
123         Iterator methodit=cn.getMethods();
124         while(methodit.hasNext()) {
125             /* Classify parameters */
126             MethodDescriptor md=(MethodDescriptor)methodit.next();
127             FlatMethod fm=state.getMethodFlat(md);
128             generateTempStructs(fm);
129
130             ParamsObject objectparams=(ParamsObject) paramstable.get(md);
131             TempObject objecttemps=(TempObject) tempstable.get(md);
132
133             /* Output parameter structure */
134             if (GENERATEPRECISEGC) {
135                 output.println("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params {");
136                 output.println("  int type;");
137                 output.println("  void * next;");
138                 for(int i=0;i<objectparams.numPointers();i++) {
139                     TempDescriptor temp=objectparams.getPointer(i);
140                     output.println("  struct "+temp.getType().getSafeSymbol()+" * "+temp.getSafeSymbol()+";");
141                 }
142                 output.println("};\n");
143             }
144
145             /* Output temp structure */
146             if (GENERATEPRECISEGC) {
147                 output.println("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_locals {");
148                 output.println("  int type;");
149                 output.println("  void * next;");
150                 for(int i=0;i<objecttemps.numPointers();i++) {
151                     TempDescriptor temp=objecttemps.getPointer(i);
152                     if (temp.getType().isNull())
153                         output.println("  void * "+temp.getSafeSymbol()+";");
154                     else
155                         output.println("  struct "+temp.getType().getSafeSymbol()+" * "+temp.getSafeSymbol()+";");
156                 }
157                 output.println("};\n");
158             }
159             
160             /* Output method declaration */
161             if (md.getReturnType()!=null) {
162                 if (md.getReturnType().isClass())
163                     headersout.print("struct " + md.getReturnType().getSafeSymbol()+" * ");
164                 else
165                     headersout.print(md.getReturnType().getSafeSymbol()+" ");
166             } else 
167                 //catch the constructor case
168                 headersout.print("void ");
169             headersout.print(cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"(");
170             
171             boolean printcomma=false;
172             if (GENERATEPRECISEGC) {
173                 headersout.print("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params * "+paramsprefix);
174                 printcomma=true;
175             }
176             for(int i=0;i<objectparams.numPrimitives();i++) {
177                 TempDescriptor temp=objectparams.getPrimitive(i);
178                 if (printcomma)
179                     headersout.print(", ");
180                 printcomma=true;
181                 if (temp.getType().isClass())
182                     headersout.print("struct " + temp.getType().getSafeSymbol()+" * "+temp.getSafeSymbol());
183                 else
184                     headersout.print(temp.getType().getSafeSymbol()+" "+temp.getSafeSymbol());
185             }
186             headersout.println(");\n");
187         }
188     }
189
190     private void generateFlatMethod(FlatMethod fm, PrintWriter output) {
191         MethodDescriptor md=fm.getMethod();
192         ClassDescriptor cn=md.getClassDesc();
193         ParamsObject objectparams=(ParamsObject)paramstable.get(md);
194
195         generateHeader(md,output);
196         /* Print code */
197         output.println(" {");
198         
199         if (GENERATEPRECISEGC) {
200             output.println("   struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_locals "+localsprefix+";");
201         }
202         TempObject objecttemp=(TempObject) tempstable.get(md);
203         for(int i=0;i<objecttemp.numPrimitives();i++) {
204             TempDescriptor td=objecttemp.getPrimitive(i);
205             TypeDescriptor type=td.getType();
206             System.out.println(td.getSafeSymbol());
207             if (type.isNull())
208                 output.println("   void * "+td.getSafeSymbol()+";");
209             else if (type.isClass())
210                 output.println("   struct "+type.getSafeSymbol()+" * "+td.getSafeSymbol()+";");
211             else
212                 output.println("   "+type.getSafeSymbol()+" "+td.getSafeSymbol()+";");
213         }
214         
215
216         /* Generate labels first */
217         HashSet tovisit=new HashSet();
218         HashSet visited=new HashSet();
219         int labelindex=0;
220         Hashtable nodetolabel=new Hashtable();
221         tovisit.add(fm.methodEntryNode());
222         FlatNode current_node=null;
223
224         //Assign labels 1st
225         //Node needs a label if it is
226         while(!tovisit.isEmpty()) {
227             FlatNode fn=(FlatNode)tovisit.iterator().next();
228             tovisit.remove(fn);
229             visited.add(fn);
230             for(int i=0;i<fn.numNext();i++) {
231                 FlatNode nn=fn.getNext(i);
232                 if(i>0) {
233                     //1) Edge >1 of node
234                     nodetolabel.put(nn,new Integer(labelindex++));
235                 }
236                 if (!visited.contains(nn)&&!tovisit.contains(nn)) {
237                     tovisit.add(nn);
238                 } else {
239                     //2) Join point
240                     nodetolabel.put(nn,new Integer(labelindex++));
241                 }
242             }
243         }
244
245         //Do the actual code generation
246         tovisit=new HashSet();
247         visited=new HashSet();
248         tovisit.add(fm.methodEntryNode());
249         while(current_node!=null||!tovisit.isEmpty()) {
250             if (current_node==null) {
251                 current_node=(FlatNode)tovisit.iterator().next();
252                 tovisit.remove(current_node);
253             }
254             visited.add(current_node);
255             if (nodetolabel.containsKey(current_node))
256                 output.println("L"+nodetolabel.get(current_node)+":");
257             if (current_node.numNext()==0) {
258                 output.print("   ");
259                 generateFlatNode(fm, current_node, output);
260                 current_node=null;
261             } else if(current_node.numNext()==1) {
262                 output.print("   ");
263                 generateFlatNode(fm, current_node, output);
264                 FlatNode nextnode=current_node.getNext(0);
265                 if (visited.contains(nextnode)) {
266                     output.println("goto L"+nodetolabel.get(nextnode)+";");
267                     current_node=null;
268                 } else
269                     current_node=nextnode;
270             } else if (current_node.numNext()==2) {
271                 /* Branch */
272                 output.print("   ");
273                 generateFlatCondBranch(fm, (FlatCondBranch)current_node, "L"+nodetolabel.get(current_node.getNext(1)), output);
274                 if (!visited.contains(current_node.getNext(1)))
275                     tovisit.add(current_node.getNext(1));
276                 if (visited.contains(current_node.getNext(0))) {
277                     output.println("goto L"+nodetolabel.get(current_node.getNext(0))+";");
278                     current_node=null;
279                 } else
280                     current_node=current_node.getNext(0);
281             } else throw new Error();
282         }
283         output.println("}\n\n");
284     }
285
286     private String generateTemp(FlatMethod fm, TempDescriptor td) {
287         MethodDescriptor md=fm.getMethod();
288         TempObject objecttemps=(TempObject) tempstable.get(md);
289         if (objecttemps.isLocalPrim(td)||objecttemps.isParamPrim(td)) {
290             return td.getSafeSymbol();
291         }
292
293         if (objecttemps.isLocalPtr(td)) {
294             return localsprefix+"."+td.getSafeSymbol();
295         }
296
297         if (objecttemps.isParamPtr(td)) {
298             return paramsprefix+"->"+td.getSafeSymbol();
299         }
300         throw new Error();
301     }
302
303     private void generateFlatNode(FlatMethod fm, FlatNode fn, PrintWriter output) {
304         switch(fn.kind()) {
305         case FKind.FlatCall:
306             generateFlatCall(fm, (FlatCall) fn,output);
307             return;
308         case FKind.FlatFieldNode:
309             generateFlatFieldNode(fm, (FlatFieldNode) fn,output);
310             return;
311         case FKind.FlatSetFieldNode:
312             generateFlatSetFieldNode(fm, (FlatSetFieldNode) fn,output);
313             return;
314         case FKind.FlatNew:
315             generateFlatNew(fm, (FlatNew) fn,output);
316             return;
317         case FKind.FlatOpNode:
318             generateFlatOpNode(fm, (FlatOpNode) fn,output);
319             return;
320         case FKind.FlatCastNode:
321             generateFlatCastNode(fm, (FlatCastNode) fn,output);
322             return;
323         case FKind.FlatLiteralNode:
324             generateFlatLiteralNode(fm, (FlatLiteralNode) fn,output);
325             return;
326         case FKind.FlatReturnNode:
327             generateFlatReturnNode(fm, (FlatReturnNode) fn,output);
328             return;
329         case FKind.FlatNop:
330             output.println("/* nop */");
331             return;
332         }
333         throw new Error();
334
335     }
336
337     private void generateFlatCall(FlatMethod fm, FlatCall fc, PrintWriter output) {
338         MethodDescriptor md=fc.getMethod();
339         ParamsObject objectparams=(ParamsObject) paramstable.get(md);
340         ClassDescriptor cn=md.getClassDesc();
341         output.println("{");
342         if (GENERATEPRECISEGC) {
343             output.print("       struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params __parameterlist__={");
344             
345             output.print(objectparams.getUID());
346             output.print(", & "+localsprefix);
347             if (fc.getThis()!=null) {
348                 output.print(", ");
349                 output.print(generateTemp(fm,fc.getThis()));
350             }
351             for(int i=0;i<fc.numArgs();i++) {
352                 VarDescriptor var=md.getParameter(i);
353                 TempDescriptor paramtemp=(TempDescriptor)temptovar.get(var);
354                 if (objectparams.isParamPtr(paramtemp)) {
355                     TempDescriptor targ=fc.getArg(i);
356                     output.print(", ");
357                     output.print(generateTemp(fm, targ));
358                 }
359             }
360             output.println("};");
361         }
362         output.print("       ");
363
364         /* TODO: Virtual dispatch */
365         if (fc.getReturnTemp()!=null)
366             output.print(generateTemp(fm,fc.getReturnTemp())+"=");
367         output.print(cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"(");
368         boolean needcomma=false;
369         if (GENERATEPRECISEGC) {
370             output.print("&__parameterlist__");
371             needcomma=true;
372         } else {
373             if (fc.getThis()!=null) {
374                 output.print(generateTemp(fm,fc.getThis()));
375                 needcomma=true;
376             }
377         }
378         for(int i=0;i<fc.numArgs();i++) {
379             VarDescriptor var=md.getParameter(i);
380             TempDescriptor paramtemp=(TempDescriptor)temptovar.get(var);
381             if (objectparams.isParamPrim(paramtemp)) {
382                 TempDescriptor targ=fc.getArg(i);
383                 if (needcomma)
384                     output.print(", ");
385                 output.print(generateTemp(fm, targ));
386                 needcomma=true;
387             }
388         }
389         output.println(");");
390         output.println("   }");
391     }
392
393     private void generateFlatFieldNode(FlatMethod fm, FlatFieldNode ffn, PrintWriter output) {
394         output.println(generateTemp(fm, ffn.getDst())+"="+ generateTemp(fm,ffn.getSrc())+"->"+ ffn.getField().getSafeSymbol()+";");
395     }
396
397     private void generateFlatSetFieldNode(FlatMethod fm, FlatSetFieldNode fsfn, PrintWriter output) {
398         output.println(generateTemp(fm, fsfn.getDst())+"->"+ fsfn.getField().getSafeSymbol()+"="+ generateTemp(fm,fsfn.getSrc())+";");
399     }
400
401     private void generateFlatNew(FlatMethod fm, FlatNew fn, PrintWriter output) {
402         output.println(generateTemp(fm,fn.getDst())+"=allocate_new("+fn.getType().getClassDesc().getId()+");");
403     }
404
405     private void generateFlatOpNode(FlatMethod fm, FlatOpNode fon, PrintWriter output) {
406
407         if (fon.getRight()!=null)
408             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+fon.getOp().toString()+generateTemp(fm,fon.getRight())+";");
409         else if (fon.getOp().getOp()==Operation.ASSIGN)
410             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+";");
411         else if (fon.getOp().getOp()==Operation.UNARYPLUS)
412             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+";");
413         else if (fon.getOp().getOp()==Operation.UNARYMINUS)
414             output.println(generateTemp(fm, fon.getDest())+" = -"+generateTemp(fm, fon.getLeft())+";");
415         else if (fon.getOp().getOp()==Operation.POSTINC)
416             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+"++;");
417         else if (fon.getOp().getOp()==Operation.POSTDEC)
418             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+"--;");
419         else if (fon.getOp().getOp()==Operation.PREINC)
420             output.println(generateTemp(fm, fon.getDest())+" = ++"+generateTemp(fm, fon.getLeft())+";");
421         else if (fon.getOp().getOp()==Operation.PREDEC)
422             output.println(generateTemp(fm, fon.getDest())+" = --"+generateTemp(fm, fon.getLeft())+";");
423         else
424             output.println(generateTemp(fm, fon.getDest())+fon.getOp().toString()+generateTemp(fm, fon.getLeft())+";");
425     }
426
427     private void generateFlatCastNode(FlatMethod fm, FlatCastNode fcn, PrintWriter output) {
428         /* TODO: Make call into runtime */
429         output.println(generateTemp(fm,fcn.getDst())+"=("+fcn.getType().getSafeSymbol()+")"+generateTemp(fm,fcn.getSrc())+";");
430     }
431
432     private void generateFlatLiteralNode(FlatMethod fm, FlatLiteralNode fln, PrintWriter output) {
433         if (fln.getValue()==null)
434             output.println(generateTemp(fm, fln.getDst())+"=0;");
435         else if (fln.getType().getSymbol().equals(TypeUtil.StringClass))
436             output.println(generateTemp(fm, fln.getDst())+"=newstring(\""+FlatLiteralNode.escapeString((String)fln.getValue())+"\");");
437         else
438             output.println(generateTemp(fm, fln.getDst())+"="+fln.getValue()+";");
439     }
440
441     private void generateFlatReturnNode(FlatMethod fm, FlatReturnNode frn, PrintWriter output) {
442         output.println("return "+generateTemp(fm, frn.getReturnTemp())+";");
443     }
444
445     private void generateFlatCondBranch(FlatMethod fm, FlatCondBranch fcb, String label, PrintWriter output) {
446         output.println("if (!"+generateTemp(fm, fcb.getTest())+") goto "+label+";");
447     }
448
449     private void generateHeader(MethodDescriptor md, PrintWriter output) {
450         /* Print header */
451         ParamsObject objectparams=(ParamsObject)paramstable.get(md);
452         ClassDescriptor cn=md.getClassDesc();
453         
454         if (md.getReturnType()!=null) {
455             if (md.getReturnType().isClass())
456                 output.print("struct " + md.getReturnType().getSafeSymbol()+" * ");
457             else
458                 output.print(md.getReturnType().getSafeSymbol()+" ");
459         } else 
460             //catch the constructor case
461             output.print("void ");
462
463         output.print(cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"(");
464         
465         boolean printcomma=false;
466         if (GENERATEPRECISEGC) {
467             output.print("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params * "+paramsprefix);
468             printcomma=true;
469         }
470         for(int i=0;i<objectparams.numPrimitives();i++) {
471             TempDescriptor temp=objectparams.getPrimitive(i);
472             if (printcomma)
473                 output.print(", ");
474             printcomma=true;
475             if (temp.getType().isClass())
476                 output.print("struct "+temp.getType().getSafeSymbol()+" * "+temp.getSafeSymbol());
477             else
478                 output.print(temp.getType().getSafeSymbol()+" "+temp.getSafeSymbol());
479         }
480         output.print(")");
481     }
482 }