首页 > 编程语言 >tensorflow1.x——如何在python多线程中调用同一个session会话

tensorflow1.x——如何在python多线程中调用同一个session会话

时间:2022-11-03 00:25:32浏览次数:69  
标签:__ sess tensorflow1 thread python session tf 多线程

如何在python多线程中调用同一个session会话?

 

这个问题源于我在看的一个强化学习代码:

https://gitee.com/devilmaycry812839668/scalable_agent

 

在众多的机器学习的分支中,凡是使用深度学习模型的基本都不会考虑本文的这个问题,不论是监督学习还是非监督学习,其学习过程都可以使用串行计算的算法逻辑来解决,但是唯独强化学习是个例外。

强化学习可以使用串行方式计算,但是由于其有采样这个操作并且采样效率还比较低下,因此在强化学习的并行采样中就出现了本文的这个问题,或者更具体的是说在一个进程的多个线程中是否可以调用同一个session,这个主进程可以是C++的也可以是python的,这个session可以是TensorFlow的同样也可以是pytorch的,为了不过大扩张本文的讨论范围,本文的设定是只讨论python下多线程对TensorFlow的同一个session会话的调用情况,不过也计划在本文后再做一个C++下多进程对TensorFlow同一个session会话的调用情况研究。

 

 

---------------------------------------

 

 

先从一个外网的poster来看这个问题:

 

[英] Reusing Tensorflow session in multiple threads causes crash(在多个线程中重用Tensorflow会话会导致崩溃)(中文翻译版本)

上文同样是从强化学习的应用背景出发,原文作者给出了一个python多线程调用统一TensorFlow会话session的代码:

import tensorflow as tf

import threading

def thread_function(sess, i):
    inn = [1.3, 4.5]
    A = tf.placeholder(dtype=float, shape=(None), name="input")
    P = tf.Print(A, [A])
    Q = tf.add(A, P)
    sess.run(Q, feed_dict={A: inn})

def main(sess):

    thread_list = []
    for i in range(0, 4):
        t = threading.Thread(target=thread_function, args=(sess, i))
        thread_list.append(t)
        t.start()

    for t in thread_list:
        t.join()

if __name__ == '__main__':

    sess = tf.Session()
    main(sess)
View Code

 

该代码很不幸,不能运行,报错:RuntimeError: The Session graph is empty. Add operations to the graph before calling run().

 

 

 

 

 

根据poster中回帖给出的解决方法可以知道,在主线程中Session启动后调用的计算图graph默认就是默认Graph,但是在子线程中则需要对使用的计算图进行指定,给出修改后的可运行的代码:

import tensorflow as tf
import threading

def thread_function(i):
    with sess.graph.as_default():
        inn = [1.3, 4.5]
        A = tf.placeholder(dtype=float, shape=(None), name="input")
        P = tf.Print(A, [A, "hello:"])
        Q = tf.add(A, P)
        # print(sess.run(Q, feed_dict={A: inn}))
        sess.run(Q, feed_dict={A: inn})

def main(sess):
    thread_list = []
    for i in range(0, 4):
        t = threading.Thread(target=thread_function, args=(i,))
        thread_list.append(t)
        t.start()

    for t in thread_list:
        t.join()

if __name__ == '__main__':
    sess = tf.Session()
    main(sess)
View Code

 

 

 

 

 

不过需要注意的是由于python中的GIL,因此python中线程是不能并发的,也就是说同一时刻多个线程中只能有一个线程在运行,因此即使多个python线程调用同一个session会话,其总的用时是单线程运行时间的累加,并不能起到加速的作用,为了更清晰的验证给出下面代码:

import tensorflow as tf
import numpy as np
import threading
import time

def thread_function(sess, Q, A, inn):
    with sess.graph.as_default():
        sess.run(Q, feed_dict={A: inn})

def main(sess):
    A = tf.placeholder(dtype=float, shape=(500, 500), name="input")
    Q = tf.Variable(tf.random_normal(shape=(500, 500)), dtype=float, name="input_variable")
    Q = tf.add(A, Q)
    for _ in range(10000):
        Q += tf.matmul(A, Q)
        # print(sess.run(Q, feed_dict={A: inn}))

    inn = np.random.random(size=(500, 500))
    thread_list = []
    a_time = time.time()

    sess.run(tf.global_variables_initializer())

    print(a_time)
    for i in range(0, 3):
        t = threading.Thread(target=thread_function, args=(sess, Q, A, inn))
        thread_list.append(t)
        t.start()

    for t in thread_list:
        t.join()
    b_time = time.time()
    print(b_time)
    print(b_time-a_time)

if __name__ == '__main__':
    sess = tf.Session()
    main(sess)

 

上面代码,在python中开三个线程调用相同的TensorFlow的Session会话,结果:

 

 

 

 

改成两个线程:

 

 

 

 

 

改成一个线程:

 

 

 

 

可以看到虽然可以使用python多线程调用同一个TensorFlow的session会话,但是并不能对性能有什么提升,多线程运算其实也是串行的,或许只有少量提升,其主要原因就是python的多线程其实不能并发运行的。

 

 

 

===========================================

 

标签:__,sess,tensorflow1,thread,python,session,tf,多线程
From: https://www.cnblogs.com/devilmaycry812839668/p/16853030.html

相关文章

  • tensorflow1.x——如何在C++多线程中调用同一个session会话
    相关内容:tensorflow1.x——如何在python多线程中调用同一个session会话 =================================================......
  • 多线程多进程拷贝文件Linux&c
    多进程拷贝文件1.Linux环境中,c语言我们利用的是fork()函数来创建新进程,通过wait()和waitpid()等函数来等待阻塞进程,通过exit()函数来结束进程。2.我在单进程中,用的是whil......
  • 多线程中的wait与join
        wait一个Object的方法,目的是将调用obj.wait()的线程置为waiting的状态,等待其他线程调用obj.notify()或者obj.notifyAll()来唤醒。最常见的就算生产者/......
  • Java多线程-ThreadLocal(六)
    为了提高CPU的利用率,工程师们创造了多线程。但是线程们说:要有光!(为了减少线程创建(T1启动)和销毁(T3切换)的时间),于是工程师们又接着创造了线程池ThreadPool。就这样就可以了吗?—......
  • java的多线程实现方式以及对应的线程锁实现
    一、多线程的实现1.1继承Thread类继承:packagecom.yuan.yk.ThreadLearn;importstaticcom.yuan.yk.ThreadLearn.func1.doSomething;publicclassThreadFuncextends......
  • aws 通过Session Manager 进入服务器
    aws服务器默认不可以像阿里云,azure等一样通过VNC或者serialconsole登录到服务器笔者在此记录aws通过SessionManager进入服务器的流程安装ssh-agent例如Debian系统,......
  • Pyhton多线程
    多线程介绍什么是线程​ 线程(Thread)也叫轻量级进程,是操作系统能够进行运算调度的最小单位,它被包涵在进程之中,是进程中的实际运作单位。​ 个线程可以创建和撤消另一个......
  • Java多线程(7):JUC(下)
    您好,我是湘王,这是我的博客园,欢迎您来,欢迎您再来~ 除了四种常见的同步器(发令枪、摇号器、栅栏和交换机),JUC还有所谓线程安全的容器、阻塞队列和一些特殊的类。其中常出现的......
  • Spring Session
    Session会话管理概述Web中的Session和Cookie回顾Session机制由于HTTP协议是无状态的协议,一次浏览器和服务器的交互过程就是:浏览器:你好吗?服务器:很好!这就是一次......
  • 多线程 & 反射 & 注解 & JDBC 核心点总结
    多线程核心点:线程安全创建线程的两种方式线程生命周期获取、修改线程名获取当前线程对象静态方法sleep()通过异常终止线程的睡眠interrupt()强行终止线程合理......