
本文介绍了如何使用 Numba 库优化 Python 中包含嵌套循环的计算密集型函数。通过 Numba 的即时编译(JIT)技术,可以将 Python 代码转换为机器码,从而显著提高程序的执行速度。本文提供了详细的代码示例和性能比较,展示了 Numba 在加速嵌套循环计算方面的强大能力,并探讨了并行化的进一步优化。
在 Python 中,当涉及到需要大量计算的嵌套循环时,程序的执行速度往往会成为瓶颈。传统的 Python 解释器在执行循环时效率较低,尤其是在处理大型数据集时。为了解决这个问题,可以使用 Numba 库来加速 Python 代码的执行。Numba 是一个开源的即时编译器,它可以将 Python 代码转换为优化的机器码,从而显著提高程序的性能。
Numba 通过装饰器(decorators)的方式来指定需要编译的函数。当 Numba 遇到被装饰的函数时,它会将该函数编译为机器码,并在后续的调用中使用编译后的版本。这种即时编译的方式可以避免 Python 解释器的开销,从而提高程序的执行速度。
以下是一个使用 Numba 加速嵌套循环的示例。假设我们有一个函数 U_p_law,它包含两个嵌套循环,用于计算某种概率分布。
立即学习“Python免费学习笔记(深入)”;
import numpy as np
from timeit import timeit
from numba import njit, prange
P_mean = 1500
P_std = 100
Q_mean = 1500
Q_std = 100
W = 1 # Number of matches won by P
L = 0 # Number of matches lost by P
L_P = np.exp(-0.5 * ((np.arange(0, 3501, 10) - P_mean) / P_std) ** 2) / (
P_std * np.sqrt(2 * np.pi)
)
L_Q = np.exp(-0.5 * ((np.arange(0, 3501, 10) - Q_mean) / Q_std) ** 2) / (
Q_std * np.sqrt(2 * np.pi)
)
def probability_of_loss(x):
return 1 / (1 + np.exp(x / 67))
def U_p_law(W, L, L_P, L_Q):
omega = np.arange(0, 3501, 10)
U_p = np.zeros_like(omega, dtype=float)
for p_idx, p in enumerate(omega):
for q_idx, q in enumerate(omega):
U_p[p_idx] += (
probability_of_loss(q - p) ** W
* probability_of_loss(p - q) ** L
* L_Q[q_idx]
* L_P[p_idx]
)
normalization_factor = np.sum(U_p)
U_p /= normalization_factor
return omega, U_p为了使用 Numba 加速这个函数,我们只需要添加 @njit 装饰器即可。
@njit
def probability_of_loss_numba(x):
return 1 / (1 + np.exp(x / 67))
@njit
def U_p_law_numba(W, L, L_P, L_Q):
omega = np.arange(0, 3501, 10, dtype=np.float64)
U_p = np.zeros_like(omega)
for p_idx, p in enumerate(omega):
for q_idx, q in enumerate(omega):
U_p[p_idx] += (
probability_of_loss_numba(q - p) ** W
* probability_of_loss_numba(p - q) ** L
* L_Q[q_idx]
* L_P[p_idx]
)
normalization_factor = np.sum(U_p)
U_p /= normalization_factor
return omega, U_p@njit 装饰器告诉 Numba 将 U_p_law_numba 函数编译为机器码。需要注意的是,为了获得最佳性能,建议在 Numba 函数中使用 NumPy 数组,并指定数组的数据类型。
对于包含大量计算的嵌套循环,还可以通过并行化来进一步提高程序的性能。Numba 提供了 prange 函数,它可以将循环并行化,从而利用多核 CPU 的优势。
@njit(parallel=True)
def U_p_law_numba_parallel(W, L, L_P, L_Q):
omega = np.arange(0, 3501, 10, dtype=np.float64)
U_p = np.zeros_like(omega)
for p_idx in prange(len(omega)):
p = omega[p_idx]
for q_idx in prange(len(omega)):
q = omega[q_idx]
U_p[p_idx] += (
probability_of_loss_numba(q - p) ** W
* probability_of_loss_numba(p - q) ** L
* L_Q[q_idx]
* L_P[p_idx]
)
normalization_factor = np.sum(U_p)
U_p /= normalization_factor
return omega, U_p要并行化 Numba 函数,需要添加 parallel=True 参数到 @njit 装饰器中,并将外层循环替换为 prange。需要注意的是,并行化可能会引入额外的开销,因此只有在循环的计算量足够大时才能获得性能提升。
以下是使用 Numba 加速后的性能比较结果。
omega_1, U_p_1 = U_p_law(W, L, L_P, L_Q)
omega_2, U_p_2 = U_p_law_numba(W, L, L_P, L_Q)
omega_3, U_p_3 = U_p_law_numba_parallel(W, L, L_P, L_Q)
assert np.allclose(omega_1, omega_2)
assert np.allclose(omega_1, omega_3)
assert np.allclose(U_p_1, U_p_2)
assert np.allclose(U_p_1, U_p_3)
t1 = timeit("U_p_law(W, L, L_P, L_Q)", number=10, globals=globals())
t2 = timeit("U_p_law_numba(W, L, L_P, L_Q)", number=10, globals=globals())
t3 = timeit("U_p_law_numba_parallel(W, L, L_P, L_Q)", number=10, globals=globals())
print("10 calls using vanilla Python :", t1)
print("10 calls using Numba :", t2)
print("10 calls using Numba (+ parallel) :", t3)在我的机器上 (AMD 5700x),输出结果如下:
10 calls using vanilla Python : 2.4276352748274803 10 calls using Numba : 0.013957140035927296 10 calls using Numba (+ parallel) : 0.003793451003730297
可以看到,使用 Numba JIT 可以获得约 170 倍的加速,而使用多线程 Numba JIT 可以获得约 640 倍的加速。
Numba 是一个强大的 Python 优化工具,它可以显著提高包含嵌套循环的计算密集型函数的执行速度。通过使用 Numba 的即时编译技术和并行化功能,可以充分利用 CPU 的性能,从而加速 Python 程序的执行。在处理大型数据集和复杂的计算任务时,Numba 可以成为提高程序性能的关键。
以上就是Python 优化:使用 Numba 加速嵌套循环计算的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号