Adding sypet to repo
[Benchmarks_CSolver.git] / sypet-non-incremental / src / edu / utexas / sypet / synthesis / SypetTestUtil.java
diff --git a/sypet-non-incremental/src/edu/utexas/sypet/synthesis/SypetTestUtil.java b/sypet-non-incremental/src/edu/utexas/sypet/synthesis/SypetTestUtil.java
new file mode 100644 (file)
index 0000000..c75af9b
--- /dev/null
@@ -0,0 +1,246 @@
+/*
+ * Copyright (C) 2017 The SyPet Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.utexas.sypet.synthesis;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.PrintWriter;
+import java.io.Writer;
+import java.lang.reflect.Method;
+import java.net.MalformedURLException;
+import java.net.URI;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.tools.DiagnosticListener;
+import javax.tools.FileObject;
+import javax.tools.ForwardingJavaFileManager;
+import javax.tools.JavaCompiler;
+import javax.tools.JavaFileObject;
+import javax.tools.SimpleJavaFileObject;
+import javax.tools.StandardJavaFileManager;
+import javax.tools.ToolProvider;
+
+import edu.utexas.sypet.synthesis.model.Benchmark;
+import edu.utexas.sypet.util.TimeUtil;
+
+public class SypetTestUtil {
+       public static final boolean DISPLAY_ERROR = false;
+       public static final String CLASSNAME = "Source";
+       protected static double mTimeCompile = 0;
+       protected static double mTimeRun = 0;
+       protected static boolean mCompilationSuccess = true;
+
+       public static boolean runTest(Benchmark bench) {
+               long start1 = System.nanoTime();
+               Class<?> compiledClass = compileClass(bench);
+               long end1 = System.nanoTime();
+               mTimeCompile = TimeUtil.computeTime(start1, end1);
+               if (compiledClass == null) {
+                       mTimeRun = 0.0;
+                       return false;
+               }
+               long start2 = System.nanoTime();
+               boolean success = false;
+               try{
+                       Method method = compiledClass.getMethod("test");
+                       //System.out.println("RESULT=" + (String) method.invoke(null));
+                       success = (boolean) method.invoke(null);
+               } catch (Exception e) {
+                       if (DISPLAY_ERROR) e.printStackTrace();
+                       success = false;
+               }
+               long end2 = System.nanoTime();
+               mTimeRun = TimeUtil.computeTime(start2, end2);
+               return success;
+       }
+
+       public static double getCompilationTime() {
+               return mTimeCompile;
+       }
+
+       public static double getRunningTime() {
+               return mTimeRun;
+       }
+
+       public static boolean isCompilationSuccess() {
+               return mCompilationSuccess;
+       }
+
+       @SuppressWarnings("rawtypes")
+       private static Class compileClass(Benchmark bench) {
+               String program = genProgram(bench);
+               if (DISPLAY_ERROR) System.out.println(program);
+               String classpath = genClassPath(bench);
+               try{
+                       @SuppressWarnings("restriction")
+//                     JavaCompiler javac = new EclipseCompiler();
+                       JavaCompiler javac = ToolProvider.getSystemJavaCompiler();
+
+                       StandardJavaFileManager sjfm = javac.getStandardFileManager(null, null, null);
+                       SpecialClassLoader cl = new SpecialClassLoader(bench);
+                       SpecialJavaFileManager fileManager = new SpecialJavaFileManager(sjfm, cl);
+
+                       List<String> options = new ArrayList<String>();
+                       options.add("-cp");
+                       options.add(classpath);
+                       List compilationUnits = Arrays.asList(new MemorySource(CLASSNAME, program));
+                       DiagnosticListener dianosticListener = null;
+                       Iterable classes = null;
+                       Writer out = DISPLAY_ERROR ? new PrintWriter(System.err) : null;
+                       JavaCompiler.CompilationTask compile = javac.getTask(out, fileManager, dianosticListener, options, classes, compilationUnits);
+                       mCompilationSuccess = compile.call();
+                       if (mCompilationSuccess) return cl.findClass(CLASSNAME);
+               } catch (Exception e){
+                       if (DISPLAY_ERROR) e.printStackTrace();
+               }
+               return null;
+       }
+
+       private static String genProgram(Benchmark bench) {
+               StringBuilder builder = new StringBuilder();
+               builder.append("public class ").append(CLASSNAME).append("{\n");
+               builder.append("public static ").append(bench.getMethodHeader()).append(" throws Throwable {\n");
+               builder.append(bench.getBody());
+               builder.append("}\n");
+               builder.append(bench.getTestBody());
+               builder.append("}\n");
+               return builder.toString();
+       }
+
+       private static String genClassPath(Benchmark bench) {
+               StringBuilder builder = new StringBuilder();
+               List<String> libs = bench.getLibs();
+               for (String lib : libs) {
+                       builder.append(lib);
+                       builder.append(':');
+               }
+               builder.append('.');
+               return builder.toString();
+       }
+
+}
+
+class MemorySource extends SimpleJavaFileObject {
+       private String src;
+       public MemorySource(String name, String src) {
+               super(URI.create("file:///" + name + ".java"), Kind.SOURCE);
+               this.src = src;
+       }
+       public CharSequence getCharContent(boolean ignoreEncodingErrors) {
+               return src;
+       }
+       public OutputStream openOutputStream() {
+               throw new IllegalStateException();
+       }
+       public InputStream openInputStream() {
+               return new ByteArrayInputStream(src.getBytes());
+       }
+}
+
+class SpecialJavaFileManager extends ForwardingJavaFileManager {
+       private SpecialClassLoader xcl;
+
+       public SpecialJavaFileManager(StandardJavaFileManager sjfm, SpecialClassLoader xcl) {
+               super(sjfm);
+               this.xcl = xcl;
+       }
+
+       public JavaFileObject getJavaFileForOutput(Location location, String name, JavaFileObject.Kind kind, FileObject sibling) throws IOException {
+               MemoryByteCode mbc = new MemoryByteCode(name);
+               xcl.addClass(name, mbc);
+               return mbc;
+       }
+
+       public ClassLoader getClassLoader(Location location) {
+               return xcl;
+       }
+}
+
+class MemoryByteCode extends SimpleJavaFileObject {
+       private ByteArrayOutputStream baos;
+
+       public MemoryByteCode(String name) {
+               super(URI.create("byte:///" + name + ".class"), Kind.CLASS);
+       }
+
+       public CharSequence getCharContent(boolean ignoreEncodingErrors) {
+               throw new IllegalStateException();
+       }
+
+       public OutputStream openOutputStream() {
+               baos = new ByteArrayOutputStream();
+               return baos;
+       }
+
+       public InputStream openInputStream() {
+               throw new IllegalStateException();
+       }
+
+       public byte[] getBytes() {
+               return baos.toByteArray();
+       }
+}
+
+class SpecialClassLoader extends ClassLoader {
+       protected Map<String, MemoryByteCode> map = new HashMap<String, MemoryByteCode>();
+       protected Benchmark mBench;
+       protected URLClassLoader cl = null;
+
+       public SpecialClassLoader(Benchmark bench) {
+               mBench = bench;
+       }
+
+       @Override
+       protected Class<?> findClass(String name) throws ClassNotFoundException {
+               MemoryByteCode mbc = map.get(name);
+               if (mbc == null){
+                       URL[] urls = getUrls(mBench);
+                       if (cl == null) {
+                               cl = new URLClassLoader(urls, Thread.currentThread().getContextClassLoader());
+                       }
+                       return cl.loadClass(name);
+               } else {
+                       return defineClass(name, mbc.getBytes(), 0, mbc.getBytes().length);
+               }
+       }
+
+       public void addClass(String name, MemoryByteCode mbc) {
+               map.put(name, mbc);
+       }
+
+       protected URL[] getUrls(Benchmark bench) {
+               ArrayList<String> libs = new ArrayList<String>(bench.getLibs());
+               URL[] urls = new URL[libs.size()];
+               try {
+                       for (int i = 0; i < libs.size(); ++i) {
+                               urls[i] = new File(libs.get(i)).toURI().toURL();
+                       }
+               } catch (MalformedURLException e) {
+                       e.printStackTrace();
+               }
+               return urls;
+       }
+}