如何在python多线程中调用同一个session会话?
这个问题源于我在看的一个强化学习代码:
https://gitee.com/devilmaycry812839668/scalable_agent
在众多的机器学习的分支中,凡是使用深度学习模型的基本都不会考虑本文的这个问题,不论是监督学习还是非监督学习,其学习过程都可以使用串行计算的算法逻辑来解决,但是唯独强化学习是个例外。
强化学习可以使用串行方式计算,但是由于其有采样这个操作并且采样效率还比较低下,因此在强化学习的并行采样中就出现了本文的这个问题,或者更具体的是说在一个进程的多个线程中是否可以调用同一个session,这个主进程可以是C++的也可以是python的,这个session可以是TensorFlow的同样也可以是pytorch的,为了不过大扩张本文的讨论范围,本文的设定是只讨论python下多线程对TensorFlow的同一个session会话的调用情况,不过也计划在本文后再做一个C++下多进程对TensorFlow同一个session会话的调用情况研究。
---------------------------------------
先从一个外网的poster来看这个问题:
[英] Reusing Tensorflow session in multiple threads causes crash(在多个线程中重用Tensorflow会话会导致崩溃)(中文翻译版本)
上文同样是从强化学习的应用背景出发,原文作者给出了一个python多线程调用统一TensorFlow会话session的代码:
import tensorflow as tf
import threading
def thread_function(sess, i):
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A])
Q = tf.add(A, P)
sess.run(Q, feed_dict={A: inn})
def main(sess):
thread_list = []
for i in range(0, 4):
t = threading.Thread(target=thread_function, args=(sess, i))
thread_list.append(t)
t.start()
for t in thread_list:
t.join()
if __name__ == '__main__':
sess = tf.Session()
main(sess)
View Code
该代码很不幸,不能运行,报错:RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
根据poster中回帖给出的解决方法可以知道,在主线程中Session启动后调用的计算图graph默认就是默认Graph,但是在子线程中则需要对使用的计算图进行指定,给出修改后的可运行的代码:
import tensorflow as tf
import threading
def thread_function(i):
with sess.graph.as_default():
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A, "hello:"])
Q = tf.add(A, P)
# print(sess.run(Q, feed_dict={A: inn}))
sess.run(Q, feed_dict={A: inn})
def main(sess):
thread_list = []
for i in range(0, 4):
t = threading.Thread(target=thread_function, args=(i,))
thread_list.append(t)
t.start()
for t in thread_list:
t.join()
if __name__ == '__main__':
sess = tf.Session()
main(sess)
View Code
不过需要注意的是由于python中的GIL,因此python中线程是不能并发的,也就是说同一时刻多个线程中只能有一个线程在运行,因此即使多个python线程调用同一个session会话,其总的用时是单线程运行时间的累加,并不能起到加速的作用,为了更清晰的验证给出下面代码:
import tensorflow as tf
import numpy as np
import threading
import time
def thread_function(sess, Q, A, inn):
with sess.graph.as_default():
sess.run(Q, feed_dict={A: inn})
def main(sess):
A = tf.placeholder(dtype=float, shape=(500, 500), name="input")
Q = tf.Variable(tf.random_normal(shape=(500, 500)), dtype=float, name="input_variable")
Q = tf.add(A, Q)
for _ in range(10000):
Q += tf.matmul(A, Q)
# print(sess.run(Q, feed_dict={A: inn}))
inn = np.random.random(size=(500, 500))
thread_list = []
a_time = time.time()
sess.run(tf.global_variables_initializer())
print(a_time)
for i in range(0, 3):
t = threading.Thread(target=thread_function, args=(sess, Q, A, inn))
thread_list.append(t)
t.start()
for t in thread_list:
t.join()
b_time = time.time()
print(b_time)
print(b_time-a_time)
if __name__ == '__main__':
sess = tf.Session()
main(sess)
上面代码,在python中开三个线程调用相同的TensorFlow的Session会话,结果:
改成两个线程:
改成一个线程:
可以看到虽然可以使用python多线程调用同一个TensorFlow的session会话,但是并不能对性能有什么提升,多线程运算其实也是串行的,或许只有少量提升,其主要原因就是python的多线程其实不能并发运行的。
===========================================
标签:__,sess,tensorflow1,thread,python,session,tf,多线程 From: https://blog.51cto.com/u_15642578/5857002