笔记:High Performance Python第9章

这一章讲的是并行计算,占了60多页。看完后得把这一大坨东西理理顺消化一下。

1.用蒙特卡洛模拟估算pi。

逻辑很简单:在坐标系的单位正方形里投针,计算落在1/4单位圆中(x^2+y^2<=1)的针的比例,然后乘以4。顺序执行,投掷100,000,000次花费的时间大约是120秒。

1.1用多进程加速

每次投掷都是彼此独立的,所以直接把工作量差成多份(比如8份),交给多个进程执行即可。可以想见,这里运行的函数不需要与程序中其他的部分共享状态,而且进程间只需要传递一小部分数据就能完成大量运算。
函数estimate_nbr_points_in_quarter_circle反复投针然后计算落在1/4单位圆内的针的个数:

def estimate_points(nbr_estimates):
    in_unit_circle = 0
    for step in range(int(nbr_estimates)):
        x = random.uniform(0, 1)
        y = random.uniform(0, 1)
        is_in_unit_circle = x * x + y * y <= 1.0
        in_unit_circle += is_in_unit_circle
    return in_unit_circle

并行部分:

from multiprocessing import Pool
...
samples_total = 1e8
N = 8
pool = Pool(processes=N)
samples = samples_total / N
trials = [samples] * N

t1 = time.time()
nbr_in_unit_circles = pool.map(estimate_points, trials)
pi_estimate = sum(nbr_in_unit_circles) * 4 / samples_total

print("Estimate pi", pi_estimate)
print("Delta:", time.time() - t1)

这是一个写起来相当简单的并行程序。确定要使用的进程个数(一般是cpu的个数),在trials列表里设置好每个进程的参数,然后往pool.map里一丢,就好像在使用普通的map函数一样。运行它大约花了19秒。

(插一句,书中这段代码读起来十分辛苦,因为变量名实在过长,到处都是nbr_trials_in_quarter_unit_circle,nbr_trials_per_process,nbr_samples_in_total这种三四个单词拼起来的变量名,把很简单的东西搞得大段大段的,绕得人头晕。这个故事告诉我们代码要简练,信息密度要高。)

Effective Python第41条指出,multiprocessing的开销是比较高的,因为主进程和子进程之间必须进行序列化和反序列化操作。具体而言,multiprocessing做了这些事:

  1. trials列表中的每一项数据都传给map
  2. pickle模块对数据进行序列化,将其变成二进制形式。
  3. 通过local socket将序列化之后的数据从主解释器所在的进程发送到子解释器所在的进程。
  4. 在子进程中用pickle对二进制数据进行反序列化操作,将其还原为python对象。
  5. 引入包含estimate_points函数的python模块。
  6. 各条子进程平行地针对各自的输入数据运行estimate_points函数。
  7. 对运行结果进行序列化操作,将其转变为字节。
  8. 将这些字节通过socket复制到主进程中。
  9. 主进程对这些字节执行反序列化操作,还原为python对象。
  10. 把每条子进程求出的结算结果合并到一份列表中,返回给调用者。

1.2并行系统中的随机数

在并行计算中得好好想想会不会得到重复的或者相关的随机序列。如果是使用python自带的random模块,multiprocessing会自己在每次fork过程中重新设置随机数生成器的种子。但如果是用numpy,就得亲自重新设置,不然random就会返回相同的序列。

使用numpy:

import numpy as np

def estimate_points(samples):
    np.random.seed()
    xs = np.random.uniform(0, 1, samples)
    ys = np.random.uniform(0, 1, samples)
    is_in_quc = (xs * xs + ys * ys) <= 1.0 
    in_quc = np.sum(is_in_quc)
    return in_quc

使用numpy使得运行时间缩减到了1.13秒。numpy可太强了(或者说CPython可太菜了)。

2.查找素数

在一个很大的范围内查找素数和估算pi是不同的,因为工作量的大小和查找范围的上下限的大小有关(检查[10, 100]和[10000, 100000]的工作量肯定是不同的),检查每个数字的复杂程度也不一样(谁知道检查到第几个素因数的时候它就被整除了呢?偶数检查起来是最简单的,素数是最难的)。这个问题是embarassingly parallel的(不会翻译...),即没有需要共享的状态。关键在于如何平衡进程之间的工作量(load balancing),将复杂度各异的任务分配给有限的计算资源。
当我们把计算任务分配给进程池的时候,我们可以控制给每个进程分配多少工作量,把工作量分成块(chunk),一旦有cpu空闲下来了就给它分配工作。块越大,通信开销越小;块越小,控制越精细。(块大小为10就是指一个进程一次检查10个数字)。作者们给出了一张”块数-运行时间“的关系图,用来说明“块数是cpu个数的倍数时运行时间最短”这个简单的道理(不然在执行最后一轮计算时会有cpu空着)。

我们可以使用队列来向一组工作进程提供任务并收集结果:

ALL_DONE = b"ALL_DONE"
WORKER_FINISHED = b"WORKER_FINISHED"

def check_prime(possibles, definites):
    while True:
        n = possibles.get()
        if n == ALL_DONE:
            definites.put(WORKER_FINISHED)
            break
        else:
            if n % 2 == 0:
                continue
            for i in range(3, int(math.sqrt(n)) + 1 , 2):
                if n % i == 0:
                    break
            else:
                definites.put(n)

possiblesdefinites为两个队列,用于结果的输入和输出。我们设置了两个标志位(flag),ALL_DONE作为终止循环的sentinel由父进程在将数字塞进possibles后提供,用于告诉子进程已经已经没有别的工作了。子进程收到ALL_DONE后,向definites输出WORKER_FINISHED,告诉父进程自己已经收到sentinel,然后终止从possibles队列获取输入。

创建输入输出队列和8个进程,向possibles队列中添加数字,并在最后加入8个ALL_DONEsentinel:

if __name__ == '__main__':
    primes = []
    possibles = Queue()
    definites = Queue()

    N = 8
    pool = Pool(processes=N)
    processes = []
    for _ in range(N):
        p = Process(target=check_prime,args=(possibles, definites))
        processes.append(p)
        p.start()
    
    t1 = time.time()
    
    number_range = range(10000000000, 10000100000)
    for possible in number_range:
        possibles.put(possible)
    print("ALL JOBS ADDED TO THE QUEUE")

    # add poison pills to stop the remote workers
    for n in range(N):
        possibles.put(ALL_DONE)
    print("NOW WAITING FOR RESULTS...")
    ...

循环地从definites队列中获取结果(当然,结果不是顺序的),得到8个WORKER_FINISHED后停止循环:

    ...
    processors_finished = 0
    while True:
        new_result = definites.get()
        if new_result == WORKER_FINISHED:
            processors_finished += 1
            print("{} WORKER(S) FINISHED".format(processors_finished))
            if processors_finished == N:
                break
        else:
            primes.append(new_result)
    assert processors_finished == N

    print("Took:", time.time() - t1)
    print(len(primes), primes[:10], primes[-10:])
    

程序执行花了7秒多,而顺序执行需要20秒左右。但由于创建队列需要序列化和同步,多进程的执行速度并不一定会比顺序执行更快。在原书中就是如此,甚至当作者把所有偶数从输入队列中剔除后多进程执行还是比顺序执行慢,说明多进程执行的程序有很大一部分时间是花费在通信开销上的。

3.验证素数

与第2节的“寻找一个范围内所有的素数”不同,现在我们来解决如何快速判断一个特别大的数(比如一个18位数)是否为素数的问题——由多个cpu合作完成。这是一个需要进程间通信或共享状态的问题。

3.1简单的进程池

与前两个例子相似,我们把要检查的数字的可能的因子分为多组,传递给多个子进程进行检查。当某个子进程中的因子整除了这个数,子进程就返回False——但这不会让别的子进程停下来(所以是一个简单的版本)。这或许会让别的子进程做无用功,但也省去了检查共享状态的通信开销。
把因子分组:

def create_range(from_i, to_i, N):
    piece_length = int((to_i - from_i) / N)
    lrs = [from_i] 
          + [(i + 1) if (i % 2 == 0) else i 
             for i in range(from_i, to_i, piece_length)[1:]]
    if len(lrs) > N:
        lrs.pop()
    assert len(lrs) == N
    ranges = list(zip(lrs, lrs[1:])) + [(lrs[-1], to_i)]
    return ranges

e.g. create_range(1000, 100000, 4)的返回值是[(1000, 25751), (25751, 50501), (50501, 75251), (75251, 100000)]。

import time
import math
from multiprocessing import Pool


def check_prime_in_range(args):
    n, from_i, to_i = args #似乎只能通过传入元组然后再拆包的方式达到传入多参数的效果
    from_i, to_i = ranges
    if n % 2 == 0:
        return False
    for i in range(from_i, to_i, 2):
        if n % i == 0:
            return False
    return True


def check_prime(n, pool, N):
    from_i = 3
    to_i = int(math.sqrt(n)) + 1
    ranges = create_range(from_i, to_i, N)
    args = [(n, from_i, to_i) for from_i, to_i in ranges]`
    results = pool.map(check_prime_in_range, args)
    if False in results:
        return False
    return True


if __name__ == "__main__":
    N = 8
    pool = Pool(processes=N)
    prime18 = 100109100129100151
    t1 = time.time()
    print("%d: %s" %(prime18, check_prime(prime18, pool, N)))
    print('Took:', time.time() - t1)

大约用了10s。

3.2稍微没那么简单的进程池

由于额外开销的原因,对于小一点的数字,多进程的方法可能还没有顺序查找的方法好。而且,如果已经找到了一个很小的因数,程序也不会马上停下来。当然,在因数时我们可以立即在进程间通信,但这会产生大量的额外通信开销,因为大多数数字都会有一个较小的因数。于是我们采用混合策略:先顺序查找较小的因数,然后再将剩余的工作分派给多个进程。这是一种避免多进程开销的常见做法。

def check_prime(n, pool, N):
    from_i = 3
    to_i= 21
    args = (n, from_i, to_i)
    if not check_prime_in_range(args):
        return False
        
    from_i = to_i
    to_i = int(math.sqrt(n)) + 1
    ranges = create_range(from_i, to_i, N)
    args = [(n, from_i, to_i) for from_i, to_i in ranges]
    results = pool.map(check_prime_in_range, args)
    if False in results:
        return False
    return True

3.3使用multiprocessing.Manager()作为标志位

直接上代码吧。可以看到这里用Manager创建了一个符号位。读取这个符号位并不需要自己做什么上锁之类的操作,就像在检查一个全局变量一样方便(不过还是要作为参数传入函数中的)。为了节约通信开销,让每个进程每检查1000个数检查一次符号位。如果进程检查到了FLAG_SET或者找到了因数就停下来。

import time
import math
from multiprocessing import Pool, Manager


SERIAL_CHECK_CUTOFF = 21
CHECK_EVERY = 1000
FLAG_CLEAR = b'0'
FLAG_SET = b'1'


def create_range(from_i, to_i, N):
    piece_length = int((to_i - from_i) / N)
    lrs = [from_i] + [(i + 1) if (i % 2 == 0) else i for i in range(from_i, to_i, piece_length)[1:]]
    if len(lrs) > N:
        lrs.pop()
        assert len(lrs) == N
    ranges = list(zip(lrs, lrs[1:])) + [(lrs[-1], to_i)]
    return ranges

def check_prime_in_range(args):
    n, from_i, to_i, value = args
    if n % 2 == 0:
        return False
    check_every = CHECK_EVERY
    for i in range(from_i, to_i, 2):
        check_every -= 1
        if not check_every:
            if value.value == FLAG_SET:
                return False
            check_every = CHECK_EVERY

        if n % i == 0:
            value.value = FLAG_SET
            return False
    return True

def check_prime(n, pool, N, value):
    from_i = 3
    to_i= SERIAL_CHECK_CUTOFF
    value.value = FLAG_CLEAR  # 要记得先设置标志位的值
    args = (n, from_i, to_i, value)
    if not check_prime_in_range(args):
        return False

    from_i = to_i
    to_i = int(math.sqrt(n)) + 1
    ranges = create_range(from_i, to_i, N)
    args = [(n, from_i, to_i, value) for from_i, to_i in ranges]
    results = pool.map(check_prime_in_range, args)
    if False in results:
        return False
    return True


if __name__ == "__main__":
    N = 8
    manager = Manager()
    value = manager.Value(b'c', FLAG_CLEAR) # 创建一个一字节(一字符)大小的符号标志位
    pool = Pool(processes=N)
    prime18 = 100109100129100151
    non_prime = 100109100129101027
    t1 = time.time()
    print("%d: %s" %(non_prime, check_prime(non_prime, pool, N, value)))
    print('Took:', time.time()-t1)

你可能感兴趣的