
在科学计算中,浮点数精度是一个常见且关键的问题。特别是在使用像numpy这样的库进行高性能数值运算时,即使是看似等效的操作也可能因为底层实现细节而产生微小的数值差异。本文将深入探讨一个具体案例:在使用np.linalg.norm计算向量范数的平方时,与直接计算平方和相比,可能引入肉眼不可见的数值不一致。
考虑以下两个NumPy数组:
import numpy as np
a = np.array([[ 0, 1, 10, 2, 5]])
b = np.array([[ 0, 1, 18, 15, 5],
[13, 9, 23, 3, 22],
[ 2, 10, 17, 4, 8]])我们通过两种方法计算a和b之间某种距离的平方,并观察它们的输出。
方法一:使用 np.linalg.norm
m1 = -np.linalg.norm(a[:, np.newaxis, :] - b[np.newaxis, :, :], axis=-1) ** 2 / 2
方法二:直接计算平方和
m2 = -np.sum(np.square(a[:, np.newaxis, :] - b[np.newaxis, :, :]), axis=-1) / 2
当我们打印这两个结果时,它们看起来是完全相同的:
print(m1) # 输出: [[-116.5 -346. -73.5]] print(m2) # 输出: [[-116.5 -346. -73.5]]
然而,当我们尝试使用np.array_equal来检查它们是否相等时,结果却出乎意料:
>>> np.array_equal(m1, m2) False
这表明尽管print()函数显示它们相同,但m1和m2在底层数值上存在差异。更有趣的是,如果我们将一个字面量数组与m1和m2进行比较:
>>> sanity_check = np.array([[-116.5, -346. , -73.5]]) >>> np.array_equal(sanity_check, m1) False >>> np.array_equal(sanity_check, m2) True
这进一步证实了m1是“异类”,它与预期的精确值不完全相等。
这种差异的根本原因在于浮点数的计算方式以及np.linalg.norm的内部实现。
np.linalg.norm的内部机制np.linalg.norm函数在计算范数时,通常会涉及到开方操作。例如,对于一个向量v,其L2范数(欧几里得范数)定义为sqrt(sum(v_i^2))。因此,np.linalg.norm(v)**2实际上是sqrt(sum(v_i^2))**2。 问题在于,在计算机中,sqrt(x)**2并不总是精确地等于x,尤其当x是一个浮点数且其平方根无法精确表示时。即使是微小的舍入误差,在后续运算中也可能累积。
我们可以通过一个简单的例子来验证这一点:
>>> np.sqrt(8**2 + 13**2)**2 232.99999999999997 >>> 8**2 + 13**2 233
这里,8**2 + 13**2的结果是整数233。但经过sqrt再square操作后,结果变成了232.99999999999997,一个微小的误差被引入。这就是m1中np.linalg.norm引入误差的机制。
np.sum(np.square(...))的优势 相比之下,方法二np.sum(np.square(a[:, np.newaxis, :] - b[np.newaxis, :, :]), axis=-1)直接计算了差值的平方和,没有引入开方操作,因此避免了上述的浮点误差来源,从而得到了更精确的结果。
print()函数显示m1和m2相同,是因为NumPy的默认打印选项对浮点数进行了舍入。NumPy通过np.set_printoptions来控制数组的打印格式,其中precision参数决定了浮点数打印的有效数字位数。
>>> np.get_printoptions()
{'edgeitems': 3, 'threshold': 1000, 'floatmode': 'maxprec', 'precision': 3, 'suppress': False, 'linewidth': 75, 'nanstr': 'nan', 'infstr': 'inf', 'sign': '-', 'formatter': None, 'legacy': False}默认情况下,precision通常设置为8(或在某些版本中为3,如本例所示),这意味着只会打印小数点后指定位数的数字。如果实际差异小于这个精度,print()函数就会将它们显示为相同。
为了揭示m1和m2的实际数值差异,我们可以将它们转换为列表,这会显示更完整的浮点数表示:
>>> m1.tolist() [[-116.49999999999999, -346.0, -73.5]] >>> m2.tolist() [[-116.5, -346.0, -73.5]]
现在,差异清晰可见:m1的第一个元素是-116.49999999999999,而m2的对应元素是精确的-116.5。
浮点数比较: 永远不要直接使用==或np.array_equal来比较浮点数,因为微小的精度差异可能导致意外的False结果。 应该使用带有容差的比较函数,例如np.allclose():
>>> np.allclose(m1, m2) True
np.allclose()允许你指定一个绝对容差(atol)和一个相对容差(rtol),只要两个数组的对应元素在这些容差范围内,就认为它们相等。
选择合适的计算方法: 在进行数值计算时,如果存在多种等效的数学表达式,应优先选择那些能避免引入额外浮点误差的方法。在本例中,直接计算平方和(np.sum(np.square(...)))优于通过np.linalg.norm再平方。
理解NumPy打印选项: 了解np.set_printoptions和np.get_printoptions的作用,可以帮助你更好地理解NumPy数组的显示方式,避免被默认的舍入输出所误导。在调试精度问题时,可以临时增加precision或使用tolist()来查看完整数值。
本教程通过一个具体的NumPy案例,深入探讨了浮点数精度在数值计算中的重要性。我们发现,np.linalg.norm由于其内部的开方再平方操作,可能引入微小的浮点误差,导致与直接平方和计算的结果不一致。同时,NumPy的默认打印机制可能隐藏这些差异。理解这些细节对于编写健壮、精确的科学计算代码至关重要。在处理浮点数时,务必使用np.allclose进行比较,并根据具体情况选择最优的计算路径以最小化误差。
以上就是NumPy浮点运算精度探究:np.linalg.norm与直接平方和的细微差异的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号