
在使用numpy进行数值计算时,我们有时会遇到看似相同但实际存在微小差异的结果。以下面的例子为例,我们尝试计算两个数组a和b之间l2范数平方的负一半。
首先,定义两个NumPy数组:
import numpy as np
a = np.array([[ 0, 1, 10, 2, 5]])
b = np.array([[ 0, 1, 18, 15, 5],
[13, 9, 23, 3, 22],
[ 2, 10, 17, 4, 8]])接下来,我们使用两种方法计算所需的结果:
方法一:使用 np.linalg.norm
这种方法利用 np.linalg.norm 函数来计算L2范数,然后进行平方。
m1 = -np.linalg.norm(a[:, np.newaxis, :] - b[np.newaxis, :, :], axis=-1) ** 2 / 2
方法二:手动展开 L2 范数平方
这种方法直接根据L2范数平方的定义,通过求差、平方和再求和的方式计算。
m2 = -np.sum(np.square(a[:, np.newaxis, :] - b[np.newaxis, :, :]), axis=-1) / 2
当我们打印这两个结果时,它们在视觉上是相同的:
print(m1) # 输出: [[-116.5 -346. -73.5]] print(m2) # 输出: [[-116.5 -346. -73.5]]
然而,当我们使用 np.array_equal 进行精确比较时,结果却出乎意料:
print(np.array_equal(m1, m2)) # 输出: False
这表明 m1 和 m2 尽管看起来一样,但底层数值并不完全相等。更有趣的是,如果我们创建一个字面量数组来检查相等性:
sanity_check = np.array([[-116.5, -346. , -73.5]]) print(np.array_equal(sanity_check, m1)) # 输出: False print(np.array_equal(sanity_check, m2)) # 输出: True
这进一步确认了 m1 是“异常”的那个。
np.linalg.norm 方法与手动计算方法之间的微小差异,主要源于浮点数运算的本质以及 np.linalg.norm 函数的内部实现。
L2 范数的定义与 np.linalg.norm 的实现
L2 范数(欧几里得范数)的定义是向量各元素平方和的平方根。即 ||x||_2 = sqrt(sum(x_i^2))。 当我们需要计算L2范数的平方时,理论上 ||x||_2^2 = sum(x_i^2)。 np.linalg.norm(..., ord=2) 在内部会执行 sqrt(sum(x_i^2)) 的操作。因此,np.linalg.norm(..., ord=2) ** 2 实际上是 (sqrt(sum(x_i^2))) ** 2。
浮点数运算的精度问题
在计算机中,浮点数(如Python中的float,NumPy中的np.float64)的表示是有限精度的。这意味着某些实数无法被精确表示,只能近似。当进行数学运算,尤其是涉及平方根等操作时,这种近似性可能导致微小的误差累积。
考虑一个简单的例子:
print(np.sqrt(8**2 + 13**2)**2) # 输出: 232.99999999999997 print(8**2 + 13**2) # 输出: 233
在这个例子中,8**2 + 13**2 结果是精确的整数 233。然而,np.sqrt(233) 会产生一个浮点数近似值,即使这个近似值再被平方,也可能无法完全恢复到原始的整数 233,而是产生一个非常接近但略有偏差的浮点数,例如 232.99999999999997。
回到我们的 m1 和 m2,m1 的计算路径是: m1 = - (np.sqrt(np.sum(np.square(diff)))) ** 2 / 2 而 m2 的计算路径是: m2 = - np.sum(np.square(diff)) / 2
m1 在中间多了一步 sqrt 操作,正是这一步引入了浮点数精度误差。为了验证这一点,我们可以查看 m1 和 m2 的原始数值:
print(m1.tolist()) # 输出: [[-116.49999999999999, -346.0, -73.5]] print(m2.tolist()) # 输出: [[-116.5, -346.0, -73.5]]
可以看到,m1 的第一个元素 -116.49999999999999 与 m2 的 -116.5 存在微小的差异。
尽管 m1 和 m2 存在实际的数值差异,但 print() 函数默认情况下却显示它们是相同的。这是因为NumPy的打印选项(由 np.set_printoptions 控制)会根据设定的精度对浮点数进行四舍五入或截断显示。
我们可以通过 np.get_printoptions() 查看当前的打印设置:
print(np.get_printoptions())
# 典型输出示例: {'edgeitems': 3, 'threshold': 1000, 'floatmode': 'maxprec', 'precision': 3, 'suppress': False, 'linewidth': 75, 'nanstr': 'nan', 'infstr': 'inf', 'sign': '-', 'formatter': None, 'legacy': False}其中,'precision': 3 表示默认显示小数点后3位。由于 m1 和 m2 的差异发生在更低的位数上,因此在默认的显示精度下,这些差异被隐藏了。
如果我们临时提高打印精度,就可以看到实际的差异:
with np.printoptions(precision=17): # 设置更高精度
print(m1)
# 输出: [[-116.49999999999998607 -346.00000000000000000 -73.50000000000000000]]
print(m2)
# 输出: [[-116.50000000000000000 -346.00000000000000000 -73.50000000000000000]]通过将 precision 设置为更高的值(例如17),我们能够清晰地看到 m1 和 m2 之间微小的数值差异。
处理浮点数精度问题是数值计算中的常见挑战。以下是一些建议和最佳实践:
由于浮点数精度问题,直接使用 == 运算符或 np.array_equal() 来比较浮点数通常是不可靠的。即使两个数在数学上应该相等,也可能因为微小的计算误差而导致它们不相等。
在比较浮点数时,应使用带有容忍度(tolerance)的比较方法。NumPy 提供了 np.allclose() 函数,它允许指定一个绝对容忍度(atol)和一个相对容忍度(rtol),只有当两个数组的对应元素之差在这些容忍度之内时,才认为它们相等。
# 检查 m1 和 m2 是否在默认容忍度下接近 print(np.allclose(m1, m2)) # 输出: True (通常默认容忍度足以覆盖这种微小差异) # 可以手动指定容忍度 print(np.allclose(m1, m2, rtol=1e-05, atol=1e-08)) # 输出: True
np.allclose() 是处理浮点数比较的标准方法。
如果你的目标是计算L2范数的平方,而不是L2范数本身,那么直接使用 np.sum(np.square(...)) 是更优的选择。这种方法避免了中间的 np.sqrt() 操作,从而减少了引入浮点数精度误差的可能性。
# 推荐计算 L2 范数平方的方法 squared_l2_norm = np.sum(np.square(a[:, np.newaxis, :] - b[np.newaxis, :, :]), axis=-1) / 2
这种方法不仅在数值上更精确,而且在某些情况下也可能略微提高计算效率,因为它省去了一次平方根运算。
本文深入探讨了在NumPy中计算L2范数平方时,np.linalg.norm 方法可能引入数值不精确性的问题。核心原因在于 np.linalg.norm 内部的平方根操作会产生浮点数误差,即使随后再进行平方也无法完全消除。同时,NumPy的默认打印精度会掩盖这些微小的差异。为了确保数值比较的准确性,我们应避免直接的浮点数相等性判断,转而使用 np.allclose() 进行容忍度比较。此外,对于L2范数的平方计算,直接使用 np.sum(np.square(...)) 是一种更精确且推荐的实践。理解这些浮点数计算的细微之处,对于编写健壮和高精度的数值代码至关重要。
以上就是理解NumPy中np.linalg.norm的数值精度差异及其浮点数比较策略的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号