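In my previous article, "Running Python Programs from Java" (Java 运行python程序), I mentioned that I want to call Python from Java to do machine learning. Java has to be the caller mainly because the open-source library I use to evaluate the model is written in Java.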
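There is still a problem, though: loading the model takes a long time on every run, and the model is loaded only once only if all the data is processed in a single batch. In practice, the data produced during a simulation may not all be available at the same time, so the Python script ends up being invoked repeatedly. Each invocation reloads the model, and that extra loading time may end up affecting the evaluation results.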
// Batch: one call, three inputs
String[] batch = {"1,2,3", "1,2,3", "1,2,3"};
System.out.println(pythonRun.run(pyPath, batch));

// One call per input
for (int i = 0; i < 3; i++) {
    String[] single = {"1,2,3"};
    System.out.println(pythonRun.run(pyPath, single));
}
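Both snippets produce the same output, but their running times differ a lot: the first processes the inputs as one batch and pays the model-loading cost once, while the second pays it three times.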
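To make the cost argument concrete, here is a toy Java model (not the author's code; the class LoadCostDemo and its methods are hypothetical) in which every launch of the script costs exactly one model load. With N inputs arriving one at a time, the total time is roughly N times the load time plus the prediction time, versus a single load for the batch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Toy model of the cost difference: no real model is loaded,
// we only count how many times a fresh launch would load it.
class LoadCostDemo {
    static int loads = 0;  // number of simulated model loads so far

    // One launch of the script: load the model once, then answer every input.
    static List<String> launch(List<String> inputs) {
        loads++;
        List<String> outputs = new ArrayList<>();
        for (String input : inputs) {
            outputs.add("result(" + input + ")");  // stand-in for a prediction
        }
        return outputs;
    }

    // Batch: all inputs in a single launch.
    static List<String> batch(List<String> data) {
        return launch(data);
    }

    // One launch per input, as when data arrives piece by piece.
    static List<String> oneByOne(List<String> data) {
        List<String> outputs = new ArrayList<>();
        for (String input : data) {
            outputs.addAll(launch(Collections.singletonList(input)));
        }
        return outputs;
    }

    public static void main(String[] args) {
        List<String> data = Arrays.asList("1,2,3", "1,2,3", "1,2,3");
        loads = 0;
        List<String> a = batch(data);
        System.out.println("batch loads: " + loads);       // batch loads: 1
        loads = 0;
        List<String> b = oneByOne(data);
        System.out.println("one-by-one loads: " + loads);  // one-by-one loads: 3
        System.out.println(a.equals(b));                   // true
    }
}
```

The outputs are identical; only the number of loads, and therefore the running time, differs.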
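If I could change the Python program so that it handles multiple rounds of input and output within a single run, the model would not have to be reloaded. That, in turn, requires Java to interact with the running Python process.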
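One way to get that interaction (a sketch, not the code used in this article) is to keep a single process alive and exchange exactly one line per request. The class LineWorker below is hypothetical. It uses ProcessBuilder rather than Runtime.exec so the child's stderr can be forwarded to the Java console, and it asks Python for UTF-8 pipes through the PYTHONIOENCODING environment variable. It assumes the Python side answers every request line with exactly one reply line.

```java
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;

// Sketch: one long-lived child process, one request line in, one reply line out.
class LineWorker implements AutoCloseable {
    private final Process process;
    private final BufferedWriter toChild;
    private final BufferedReader fromChild;

    LineWorker(String... command) throws IOException {
        ProcessBuilder pb = new ProcessBuilder(command);
        // Forward the child's stderr to our console: an unread stderr pipe can fill up
        // and block the child (model libraries often log a lot there).
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        // Ask Python to use UTF-8 on stdin/stdout so both sides agree on the encoding.
        pb.environment().put("PYTHONIOENCODING", "utf-8");
        process = pb.start();
        toChild = new BufferedWriter(new OutputStreamWriter(process.getOutputStream(), StandardCharsets.UTF_8));
        fromChild = new BufferedReader(new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8));
    }

    // Sends one line and blocks until the child answers with one line.
    // Returns null if the child exited instead of answering.
    String request(String line) throws IOException {
        toChild.write(line);
        toChild.write('\n');
        toChild.flush();  // without this the request may sit in Java's buffer
        return fromChild.readLine();
    }

    @Override
    public void close() throws IOException {
        toChild.close();  // end of input: a line-reading child can now exit by itself
        try {
            if (!process.waitFor(5, TimeUnit.SECONDS)) {
                process.destroyForcibly();  // did not exit in time
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            process.destroyForcibly();
        } finally {
            fromChild.close();
        }
    }
}
```

A possible call would be new LineWorker(pyEnvironment, "-u", pyPath), where Python's -u flag makes stdout unbuffered so replies are not held back. The test uses the Unix cat command only as a stand-in process that echoes each line.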
The core code is as follows:
public class PythonRun {
    private String environment = "python";
    private String root = null;
    private String cache = "cache/";
    private boolean autoRemoveCache = true;

    public static class AResult {
        private PythonRun pythonRun;
        private Process process;
        private String path;
        private BufferedWriter out;
        private BufferedReader in;

        public AResult(PythonRun pythonRun, Process process, String path) {
            this.pythonRun = pythonRun;
            this.process = process;
            this.path = path;
            // Java writes to the script's stdin and reads from its stdout
            out = new BufferedWriter(new OutputStreamWriter(process.getOutputStream()));
            in = new BufferedReader(new InputStreamReader(process.getInputStream()));
        }

        public void close() {
            if (pythonRun.autoRemoveCache && path != null)
                new File(path).delete();
            process.destroy();
        }

        // Sends one line to the script; flush() so it is not left in Java's buffer
        public void input(String message) throws IOException {
            out.write(message + "\n");
            out.flush();
        }

        // Blocks until the script prints a line, then also collects any lines that
        // are already available. Returns null once the script has exited.
        public String getResult() throws IOException {
            String line;
            StringBuilder result = new StringBuilder();
            do {
                line = in.readLine();
                if (line == null)  // end of stream: the Python process has exited
                    return result.length() == 0 ? null : result.toString();
                result.append(line).append("\n");
            } while (in.ready());
            return result.toString();
        }
    }

    // Starts the script once and returns a handle for repeated input()/getResult() calls
    public AResult asyncRun(String path, String... args) throws IOException {
        String script = createNewPy(path);  // createNewPy is listed in the full code below
        if (script == null)
            throw new FileNotFoundException(path);
        List<String> inputArgs = new LinkedList<>(Arrays.asList(environment, script));
        inputArgs.addAll(Arrays.asList(args));
        Process process = Runtime.getRuntime().exec(inputArgs.toArray(new String[0]));
        return new AResult(this, process, script);
    }
}
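getResult() uses in.ready() to decide that a reply is complete, but ready() only reports whether data is already buffered. If the script writes a multi-line reply in pieces, getResult() can return early and the rest is picked up by the next call. A more robust convention (an assumption, not part of the original design) is for the Python script to print a fixed marker line after each reply, for example print("<<END>>", flush=True), and for Java to read until that marker. SentinelReader and the marker text are hypothetical.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;

class SentinelReader {
    // Hypothetical end-of-reply marker; the Python script must print it after every reply.
    static final String END = "<<END>>";

    // Reads lines until the END marker and returns them, each ending in '\n'.
    // Returns null if the stream ends before the marker (the Python process exited);
    // any partial reply is discarded in that case.
    static String readReply(BufferedReader in) throws IOException {
        StringBuilder reply = new StringBuilder();
        String line;
        while ((line = in.readLine()) != null) {
            if (line.equals(END)) {
                return reply.toString();
            }
            reply.append(line).append('\n');
        }
        return null;
    }

    public static void main(String[] args) throws IOException {
        // Simulated Python output: a two-line reply, then a one-line reply
        BufferedReader in = new BufferedReader(new StringReader("a\nb\n<<END>>\nc\n<<END>>\n"));
        System.out.print(readReply(in));    // a and b on two lines
        System.out.print(readReply(in));    // c
        System.out.println(readReply(in));  // null: the stream has ended
    }
}
```

With this convention a reply may span any number of lines and arrive in any number of chunks; Java only stops at the marker or at end of stream.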
Here is how it works in practice.
The Java test code:
String pyPath = "E:\\pythonProject\\MEC-Study\\src\\test\\testForComm.py";
String pyEnvironment = "E:\\Anaconda3\\envs\\MEC-Study\\python.exe";
PythonRun pythonRun = new PythonRun();
pythonRun.setEnvironment(pyEnvironment);
pythonRun.setRoot("E:\\pythonProject\\MEC-Study\\src");
PythonRun.AResult aResult = pythonRun.asyncRun(pyPath);
aResult.input("a");
System.out.println(aResult.getResult());
aResult.input("b");
System.out.println(aResult.getResult());
aResult.input("exit");
System.out.println(aResult.getResult());
aResult.close();
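getResult() blocks in readLine() until the script prints something or exits, so if the script hangs (for example while loading a model), the test above never returns. A hedged sketch of a bounded wait follows; the helper TimedRead.readLine is hypothetical and built on a standard ExecutorService.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.util.concurrent.*;

class TimedRead {
    // Reads one line, but gives up after timeoutMs instead of blocking forever.
    static String readLine(BufferedReader in, long timeoutMs)
            throws IOException, InterruptedException, TimeoutException {
        // Daemon thread, so a read that never returns cannot keep the JVM alive
        ExecutorService executor = Executors.newSingleThreadExecutor(r -> {
            Thread t = new Thread(r, "timed-read");
            t.setDaemon(true);
            return t;
        });
        try {
            Callable<String> task = in::readLine;
            return executor.submit(task).get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (ExecutionException e) {
            // readLine itself failed; report it as an IOException
            throw new IOException(e.getCause());
        } finally {
            // shutdownNow() interrupts the worker, but that does not reliably
            // unblock a read on a process pipe
            executor.shutdownNow();
        }
    }
}
```

The helper does not cancel the underlying read: after a timeout the blocked thread may still consume the next line, so the safe reaction is to treat the session as broken, call close() and start a new process.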
The Python test code:
if __name__ == '__main__':
    message = input()
    while message != "exit":
        # flush=True sends each reply to Java immediately even though stdout is a pipe
        print(message, flush=True)
        message = input()
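Output (the final null appears because the script exits after receiving exit, so readLine() reaches the end of the stream):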
a
b
null
Process finished with exit code 0
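The test script only echoes text, while the first example in this article passed comma-separated strings such as "1,2,3" to the script. If a model script reads one feature vector per line, Java needs to turn numbers into such a line and parse the reply. The helpers below are a sketch under that assumption; the class name FeatureLine and the line format are hypothetical. On the Python side the matching parse could be [float(x) for x in message.split(",")].

```java
import java.util.Arrays;
import java.util.stream.Collectors;

// Hypothetical line format: one request = comma-separated numbers, e.g. "1.0,2.0,3.0".
class FeatureLine {
    // double[] -> "1.0,2.0,3.0"; Double.toString round-trips exactly through parseDouble
    static String encode(double[] features) {
        return Arrays.stream(features)
                .mapToObj(Double::toString)
                .collect(Collectors.joining(","));
    }

    // "0.5,0.25" -> {0.5, 0.25}; trims whitespace such as a trailing '\r'
    static double[] decode(String line) {
        return Arrays.stream(line.trim().split(","))
                .mapToDouble(s -> Double.parseDouble(s.trim()))
                .toArray();
    }
}
```

For example, aResult.input(FeatureLine.encode(features)) would send one vector per call, and FeatureLine.decode could parse a numeric reply line.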
The complete code:
import java.io.*;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

public class PythonRun {
    private String environment = "python";
    private String root = null;
    private String cache = "cache/";
    private boolean autoRemoveCache = true;
    // Runs the script once, waits for it to exit and returns everything it printed
    public String run(String path, String... args) throws IOException {
        String script = createNewPy(path);
        if (script == null)
            throw new FileNotFoundException(path);
        List<String> inputArgs = new LinkedList<>(Arrays.asList(environment, script));
        inputArgs.addAll(Arrays.asList(args));
        StringBuilder result = new StringBuilder();
        try {
            Process proc = Runtime.getRuntime().exec(inputArgs.toArray(new String[0]));
            // Only stdout is read; a script that writes a lot to stderr can stall
            // once that unread pipe fills up
            BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream()));
            String line;
            while ((line = in.readLine()) != null) {
                result.append(line).append("\n");
            }
            in.close();
            proc.waitFor();
        } catch (Exception e) {
            e.printStackTrace();
        }
        if (autoRemoveCache)
            new File(script).delete();
        return result.toString();
    }
    public static class AResult {
        private PythonRun pythonRun;
        private Process process;
        private String path;
        private BufferedWriter out;
        private BufferedReader in;
        public AResult(PythonRun pythonRun, Process process, String path) {
            this.pythonRun = pythonRun;
            this.process = process;
            this.path = path;
            // Java writes to the script's stdin and reads from its stdout
            out = new BufferedWriter(new OutputStreamWriter(process.getOutputStream()));
            in = new BufferedReader(new InputStreamReader(process.getInputStream()));
        }
        public void close() {
            if (pythonRun.autoRemoveCache && path != null)
                new File(path).delete();
            process.destroy();
        }
        // Sends one line to the script; flush() so it is not left in Java's buffer
        public void input(String message) throws IOException {
            out.write(message + "\n");
            out.flush();
        }
        // Blocks until the script prints a line, then also collects any lines that
        // are already available. Returns null once the script has exited.
        public String getResult() throws IOException {
            String line;
            StringBuilder result = new StringBuilder();
            do {
                line = in.readLine();
                if (line == null)  // end of stream: the Python process has exited
                    return result.length() == 0 ? null : result.toString();
                result.append(line).append("\n");
            } while (in.ready());
            return result.toString();
        }
    }
    // Starts the script once and returns a handle for repeated input()/getResult() calls
    public AResult asyncRun(String path, String... args) throws IOException {
        String script = createNewPy(path);
        if (script == null)
            throw new FileNotFoundException(path);
        List<String> inputArgs = new LinkedList<>(Arrays.asList(environment, script));
        inputArgs.addAll(Arrays.asList(args));
        Process process = Runtime.getRuntime().exec(inputArgs.toArray(new String[0]));
        return new AResult(this, process, script);
    }
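    // Copies the script into the cache directory; if root is set, prepends
    // sys.path.append(root) so the script can import modules from the project root
    private String createNewPy(String path) {
        File file = new File(path);
        if (file.isFile()) {
            String result = loadTxt(file);
            if (root != null) {
                // Escape backslashes and quotes: a Windows path such as C:\Users\...
                // would otherwise be parsed by Python as escape sequences (\U...)
                String quoted = root.replace("\\", "\\\\").replace("\"", "\\\"");
                result = "import sys\n" +
                        "sys.path.append(\"" + quoted + "\")\n" + result;
            }
            String save = cache + file.getName();
            saveTxt(save, result);
            return save;
        }
        return null;
    }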
    private static File saveTxt(String filename, String string) {
        File file = new File(filename);
        File dir = file.getAbsoluteFile().getParentFile();
        if (dir != null)
            dir.mkdirs();  // the cache directory may not exist yet
        try (BufferedWriter out = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(file), "UTF-8"))) {
            out.write(string);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return file;
    }
    private String loadTxt(File file) {
        StringBuilder result = new StringBuilder();
        // try-with-resources closes the file even if reading fails
        try (BufferedReader in = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), "UTF-8"))) {
            String str;
            while ((str = in.readLine()) != null) {
                result.append(str).append("\n");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return result.toString();
    }
    public String getCache() {
        return cache;
    }
    public void setCache(String cache) {
        this.cache = cache;
    }
    public String getEnvironment() {
        return environment;
    }
    public void setEnvironment(String environment) {
        this.environment = environment;
    }
    public String getRoot() {
        return root;
    }
    public void setRoot(String root) {
        this.root = root;
    }
    public boolean isAutoRemoveCache() {
        return autoRemoveCache;
    }
    public void setAutoRemoveCache(boolean autoRemoveCache) {
        this.autoRemoveCache = autoRemoveCache;
    }
}
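createNewPy copies the script into the cache directory only to prepend sys.path.append(root). An alternative (not the approach used above) is to leave the script untouched and pass root through the PYTHONPATH environment variable, which Python adds to its module search path at startup. That removes the cache file and the need to embed a Windows path inside generated Python source. The class PythonLauncher and its buildCommand method are hypothetical.

```java
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class PythonLauncher {
    // Builds (but does not start) a process for: <python> -u <script> <args...>,
    // with root added to Python's module search path via PYTHONPATH.
    static ProcessBuilder buildCommand(String python, String root, String script, String... args) {
        // -u: unbuffered stdout/stderr, so replies reach Java immediately
        List<String> command = new ArrayList<>(Arrays.asList(python, "-u", script));
        command.addAll(Arrays.asList(args));
        ProcessBuilder pb = new ProcessBuilder(command);
        if (root != null) {
            String old = pb.environment().get("PYTHONPATH");
            // Keep existing entries; File.pathSeparator is ';' on Windows and ':' elsewhere
            pb.environment().put("PYTHONPATH",
                    old == null || old.isEmpty() ? root : root + File.pathSeparator + old);
        }
        // Forward stderr to the Java console so it can never fill an unread pipe
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        return pb;
    }
}
```

With the paths from the test above, buildCommand(pyEnvironment, "E:\\pythonProject\\MEC-Study\\src", pyPath).start() would launch the original script directly.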