
我发现 jax 中的 vmap 在应用于多个参数时不会按预期运行。例如,考虑下面的函数:
def f1(x, y, z):
f = x[:, none, none] * z[none, none, :] + y[none, :, none]
return f对于 x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),该函数的输出形状为 (7, 5, 3)。但是,对于以下 vmap 版本:
@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
f = x*z + y
return f它输出此错误:
ValueError: vmap got inconsistent sizes for array axes to be mapped: * one axis had size 5: axis 0 of argument y of type int32[5]; * one axis had size 3: axis 0 of argument z of type int32[3]
有人可以解释一下这个错误背后的原因吗?
对于一个刚进入PHP 开发大门的程序员,最需要的就是一本实用的开发参考书,而不仅仅是各种快速入门的only hello wold。在开发的时候,也要注意到许多技巧和一些“潜规则”。PHP是一门很简单的脚本语言,但是用好它,也要下功夫的。同时,由于PHP 的特性,我一再强调,最NB 的PHP 程序员都不是搞PHP 的。为什么呢?因为PHP 作为一种胶水语言,用于粘合后端 数据库和前端页面,更多需
387
vmap 的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0) 时,含义是“同时沿 y 和 z 的前导维度映射”:您看到的错误告诉您 y 和 y 的前导维度具有不同的大小,因此它们不兼容批处理。
您的函数 f1 本质上使用广播来编码三个批处理操作,因此要使用 vmap 复制该逻辑,您将需要 vmap 的三个应用程序。您可以这样表达:
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
f = x*z + y
return f
以上就是JAX `vmap` 对于多个参数的意外行为的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号