ThreadPoolExecutor源码分析

Python 小记 2019-02-18 8226 字 1677 浏览 点赞

前言

Python中,ThreadPoolExecutor对Thread做了进一步封装。在Thread基础之上,使得多线程开发更简单了。另一方面,由于还存在ProcessPoolExecutor类,多线程与多进程的开发接口得到了统一。

在整个过程中,需要理清ThreadPoolExecutor的成员方法Future的成员方法

ThreadPoolExecutor

一个简单的多线程

ThreadPoolExecutor在concurrent.futures模块下,一个简单的多线程代码如下:

import time
from concurrent.futures import ThreadPoolExecutor

def print_hello():
    for i in range(10):
        time.sleep(1)
        print("h{} hello".format(i))

def print_world():
    for i in range(10):
        time.sleep(1)
        print("w{} world".format(i))

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=2)
    task1th = executor.submit(print_hello)
    task2ed = executor.submit(print_world)

# 输出:
h0 hello
w0 world
h1 hello
w1 world
h2 hello
...

可以见到,打印结果中的“hello”和“world”是交叉出现,这符合多线程行为。

submit

在上面的demo中,ThreadPoolExecutor(max_workers=2)表示创建一个线程池,而它的管理人就是这里的实例对象executor,executor有一个submit()方法,用来提交子线程需要执行的任务——在这里分别是函数print_hello()和函数print_world(),每个任务对应一个线程。跟threading.Thread()不同,你不需要用什么命令让它“动”起来(threading.Thread()中需要start()),当你submit之后,子线程就去执行了。

下面是submit()方法的源码:

# submit源码
class ThreadPoolExecutor(_base.Executor):
    ...
    def submit(self, fn, *args, **kwargs):
        with self._shutdown_lock:
            ...
            f = _base.Future()
            ...
            self._adjust_thread_count()  # 在submit中执行了_adjust_thread_count()
            return f  # 返回Future的对象
                    
    def _adjust_thread_count(self):
        num_threads = len(self._threads)
        if num_threads < self._max_workers:
            thread_name = '%s_%d' % (self._thread_name_prefix or self,
                                     num_threads)
            # 创建一个线程对象t
            t = threading.Thread(name=thread_name, target=_worker,
                                 args=(weakref.ref(self, weakref_cb),
                                       self._work_queue,
                                       self._initializer,
                                       self._initargs))
            t.daemon = True
            t.start()  # 启动线程
            ...

shutdown

注意,在_adjust_thread_count()中并没有执行join()方法,也就是说,子线程执行的同时,主线程也会向下执行。ThreadPoolExecutor的成员函数中有一个shutdown(),当参数wait=True时可以用来阻塞主线程,其本质是调用了每个子线程的join()

# shutdown源码
class ThreadPoolExecutor(_base.Executor):
    ...
    def shutdown(self, wait=True):
        ...
        if wait:
            for t in self._threads:  #  self._threads是用来存放子线程的集合
                t.join()  # 调用每个子线程的join方法

在前面代码的基础上做些小修改,shutdown()的作用立竿见影:

...
if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=2)
    task1th = executor.submit(print_hello)
    task2ed = executor.submit(print_world)
    # 使用shutdown方法
    executor.shutdown()
    print("zty")

# 输出:
w0 world
h0 hello
...
w9 world
h9 hello
zty  # zty最后输出
...
if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=2)
    task1th = executor.submit(print_hello)
    task2ed = executor.submit(print_world)
    # 不使用shutdown方法
    print("zty")

# 输出:
zty  # zty最先输出
w0 world
h0 hello
...
w9 world
h9 hello

map

提交线程任务除了submit()外,还提供了map()方法。此map与Python内置的map在使用上相似,它可以批量启动相同函数的不同线程。

def print_num(num):
    print(num)

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=3)
    executor.map(print_num, [1, 2, 3])

# 输出:
1
2
3

上述代码在运行过程中启动了三个子线程,三个子线程又各自只打印了一个数字。事实上,在Python内部,map()的实现也是基于submit()。

# map源码
class Executor(object):
    def map(self, fn, *iterables, timeout=None, chunksize=1):
        ...
        fs = [self.submit(fn, *args) for args in zip(*iterables)]
        ...

map()函数接收参数的设计我认为比较巧妙,在zip()的帮忙下,你可以传递多个可迭代对象。

def print_num(num, alpha):
    print(num, alpha)

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=3)
    executor.map(print_num, [1, 2, 3], ["a", "b", "c"])

# 输出:
1 a  # 注意成对输出
2 b
3 c

最后,map会返回一个生成器,里面放着线程的运行结果。

# map源码
class Executor(object):
    def map(self, fn, *iterables, timeout=None, chunksize=1):
        ...
        def result_iterator():
            try:
                # reverse to keep finishing order
                fs.reverse()
                while fs:
                    # Careful not to keep a reference to the popped future
                    if timeout is None:
                        yield fs.pop().result()  # 调用future的result()方法拿到结果
                    else:
                        yield fs.pop().result(end_time - time.time())
                        # 调用future的result()方法拿到结果
            ...
        return result_iterator()  # 返回这个生成器

使用示例:

def print_num(n1, n2):
    return n1+n2

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=3)
    results = executor.map(print_num, [1, 2, 3], [11, 12, 13])
    for result in results:
        print(result)
# 输出:
12
14
16

构造函数接收参数

回归到ThreadPoolExecutor类,查看它的__init__()方法:

# __init__源码
class ThreadPoolExecutor(_base.Executor):
    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=()):
        ....

其接收四个参数,意义分别为:

  • max_workers 表示允许的最大线程数量
  • thread_name_prefix 表示线程名字的前缀(thread_0、thread_1、thread_2中的thread)
  • initializer 表示子线程执行前需要执行的函数
  • initargs 表示initializer函数接收的参数
def print_num(num):
    print(num)

def print_two():
    print("two")

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=1, 
                                  initializer=print_num, 
                                  initargs=("one",))
    task1th = executor.submit(print_two)

# 输出:
one  # 在打印two之前,先打印了one
two

有以下几点需要注意:

  • initializer代表的函数在调用submit()后被执行,而不是executor初始化时;
  • initializer代表的函数出现异常,后面的线程将不再执行。

整个逻辑被写在了_work()函数中:

# concurrent/futures/thread.py下_work()源码
def _worker(executor_reference, work_queue, initializer, initargs):
    if initializer is not None:
        try:
            initializer(*initargs)  # 执行初始化函数
        except BaseException:
            ...
            return
    try:
        while True:
            work_item = work_queue.get(block=True)  # 从队列中取出线程
            if work_item is not None:
                work_item.run()  # 执行线程
                # Delete references to object. See issue16284
                del work_item
                continue
            ...

Future

submit()会返回Future对象。前面已经贴过submit的源码,这里不做重复。Future对象包含以下几个重要方法:

  • add_done_callback() 接收一个函数名,当线程执行完后调用传入函数;
  • result() 用于获取线程执行的结果;
  • exception() 用于获取线程执行过程中存在的异常。

add_done_callback()result()的使用示例:

def print_and_return_num():
    print(512)
    return 512

def print_three(*args):
    print("结束")

if __name__ == "__main__":
    executor = ThreadPoolExecutor(max_workers=1)
    task = executor.submit(print_and_return_num)
    
    print("程序执行结果:", task.result())  # 打印线程的执行结果
    task.add_done_callback(print_three)   # 线程结束后执行print_three

# 输出:
512
程序执行结果: 512
结束

我们再来看看add_done_callback()的源码:

# add_done_callback源码
class Future(object):
    ...
    def add_done_callback(self, fn):
        ...
        with self._condition:
            if self._state not in [CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED]:
                self._done_callbacks.append(fn)
                return
        fn(self)  # fn被调用

也就是说,add_done_callback接收的函数fn,必须接收一个参数,Python会在调用它时,默认把对象(self)传进去。


同时,Future还提供了获取线程状态的三个方法:

  • cancelled() 当线程状态是CANCELLED或者CANCELLED_AND_NOTIFIED,返回True;
  • running() 当线程状态是RUNNING,返回True;
  • done() 当线程状态是CANCELLEDCANCELLED_AND_NOTIFIED或者FINISHED,返回True。

也还有一个我暂时不知道有什么的方法:cancel(),官方文档对其释义:

Attempt to cancel the call. If the call is currently being executed and cannot be cancelled then the method will return False, otherwise the call will be cancelled and the method will return True.
当程序正在执行时(RUNNING)或者处于不可以被取消的状态时(FINISHED),返回False。否则取消调用,并且返回True。


本文由 Guan 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

还不快抢沙发

添加新评论