Python 优化:使用 Numba 加速嵌套循环计算

碧海醫心
发布: 2025-10-17 13:35:00
原创
155人浏览过

python 优化:使用 numba 加速嵌套循环计算

本文介绍了如何使用 Numba 库优化 Python 中包含嵌套循环的计算密集型函数。通过 Numba 的即时编译(JIT)技术,可以将 Python 代码转换为机器码,从而显著提高程序的执行速度。本文提供了详细的代码示例和性能比较,展示了 Numba 在加速嵌套循环计算方面的强大能力,并探讨了并行化的进一步优化。

在 Python 中,当涉及到需要大量计算的嵌套循环时,程序的执行速度往往会成为瓶颈。传统的 Python 解释器在执行循环时效率较低,尤其是在处理大型数据集时。为了解决这个问题,可以使用 Numba 库来加速 Python 代码的执行。Numba 是一个开源的即时编译器,它可以将 Python 代码转换为优化的机器码,从而显著提高程序的性能。

Numba 简介

Numba 通过装饰器(decorators)的方式来指定需要编译的函数。当 Numba 遇到被装饰的函数时,它会将该函数编译为机器码,并在后续的调用中使用编译后的版本。这种即时编译的方式可以避免 Python 解释器的开销,从而提高程序的执行速度。

使用 Numba 加速嵌套循环

以下是一个使用 Numba 加速嵌套循环的示例。假设我们有一个函数 U_p_law,它包含两个嵌套循环,用于计算某种概率分布。

立即学习Python免费学习笔记(深入)”;

import numpy as np
from timeit import timeit
from numba import njit, prange

P_mean = 1500
P_std = 100
Q_mean = 1500
Q_std = 100
W = 1  # Number of matches won by P
L = 0  # Number of matches lost by P
L_P = np.exp(-0.5 * ((np.arange(0, 3501, 10) - P_mean) / P_std) ** 2) / (
    P_std * np.sqrt(2 * np.pi)
)
L_Q = np.exp(-0.5 * ((np.arange(0, 3501, 10) - Q_mean) / Q_std) ** 2) / (
    Q_std * np.sqrt(2 * np.pi)
)


def probability_of_loss(x):
    return 1 / (1 + np.exp(x / 67))


def U_p_law(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10)

    U_p = np.zeros_like(omega, dtype=float)

    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (
                probability_of_loss(q - p) ** W
                * probability_of_loss(p - q) ** L
                * L_Q[q_idx]
                * L_P[p_idx]
            )

    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor

    return omega, U_p
登录后复制

为了使用 Numba 加速这个函数,我们只需要添加 @njit 装饰器即可。

@njit
def probability_of_loss_numba(x):
    return 1 / (1 + np.exp(x / 67))


@njit
def U_p_law_numba(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10, dtype=np.float64)

    U_p = np.zeros_like(omega)

    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (
                probability_of_loss_numba(q - p) ** W
                * probability_of_loss_numba(p - q) ** L
                * L_Q[q_idx]
                * L_P[p_idx]
            )

    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor

    return omega, U_p
登录后复制

@njit 装饰器告诉 Numba 将 U_p_law_numba 函数编译为机器码。需要注意的是,为了获得最佳性能,建议在 Numba 函数中使用 NumPy 数组,并指定数组的数据类型。

算家云
算家云

高效、便捷的人工智能算力服务平台

算家云 37
查看详情 算家云

并行化 Numba 函数

对于包含大量计算的嵌套循环,还可以通过并行化来进一步提高程序的性能。Numba 提供了 prange 函数,它可以将循环并行化,从而利用多核 CPU 的优势。

@njit(parallel=True)
def U_p_law_numba_parallel(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10, dtype=np.float64)

    U_p = np.zeros_like(omega)

    for p_idx in prange(len(omega)):
        p = omega[p_idx]
        for q_idx in prange(len(omega)):
            q = omega[q_idx]
            U_p[p_idx] += (
                probability_of_loss_numba(q - p) ** W
                * probability_of_loss_numba(p - q) ** L
                * L_Q[q_idx]
                * L_P[p_idx]
            )

    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor

    return omega, U_p
登录后复制

要并行化 Numba 函数,需要添加 parallel=True 参数到 @njit 装饰器中,并将外层循环替换为 prange。需要注意的是,并行化可能会引入额外的开销,因此只有在循环的计算量足够大时才能获得性能提升。

性能比较

以下是使用 Numba 加速后的性能比较结果。

omega_1, U_p_1 = U_p_law(W, L, L_P, L_Q)
omega_2, U_p_2 = U_p_law_numba(W, L, L_P, L_Q)
omega_3, U_p_3 = U_p_law_numba_parallel(W, L, L_P, L_Q)

assert np.allclose(omega_1, omega_2)
assert np.allclose(omega_1, omega_3)
assert np.allclose(U_p_1, U_p_2)
assert np.allclose(U_p_1, U_p_3)

t1 = timeit("U_p_law(W, L, L_P, L_Q)", number=10, globals=globals())
t2 = timeit("U_p_law_numba(W, L, L_P, L_Q)", number=10, globals=globals())
t3 = timeit("U_p_law_numba_parallel(W, L, L_P, L_Q)", number=10, globals=globals())

print("10 calls using vanilla Python     :", t1)
print("10 calls using Numba              :", t2)
print("10 calls using Numba (+ parallel) :", t3)
登录后复制

在我的机器上 (AMD 5700x),输出结果如下:

10 calls using vanilla Python     : 2.4276352748274803
10 calls using Numba              : 0.013957140035927296
10 calls using Numba (+ parallel) : 0.003793451003730297
登录后复制

可以看到,使用 Numba JIT 可以获得约 170 倍的加速,而使用多线程 Numba JIT 可以获得约 640 倍的加速。

注意事项

  • Numba 对 Python 代码有一定的限制,例如不支持所有的 Python 特性。在使用 Numba 之前,需要确保代码满足 Numba 的要求。
  • Numba 的编译过程需要一定的时间,因此在第一次调用 Numba 函数时可能会比较慢。但是,在后续的调用中,Numba 会使用编译后的版本,从而提高程序的执行速度。
  • 并行化可能会引入额外的开销,因此只有在循环的计算量足够大时才能获得性能提升。

总结

Numba 是一个强大的 Python 优化工具,它可以显著提高包含嵌套循环的计算密集型函数的执行速度。通过使用 Numba 的即时编译技术和并行化功能,可以充分利用 CPU 的性能,从而加速 Python 程序的执行。在处理大型数据集和复杂的计算任务时,Numba 可以成为提高程序性能的关键。

以上就是Python 优化:使用 Numba 加速嵌套循环计算的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号