我正在尝试运行用 jax 编写的模型, https://github.com/lindermanlab/S5 但是,我遇到了一些错误,显示
Traceback (most recent call last):
File "/Path/run_train.py", line 101, in <module>
train(parser.parse_args())
File "/Path/train.py", line 144, in train
state = create_train_state(model_cls,
File "/Path/train_helpers.py", line 135, in create_train_state
params = variables["params"].unfreeze()
AttributeError: 'dict' object has no attribute 'unfreeze'
我尝试通过
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
model = nn.Dense(features=3)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
params_unfrozen = flax.traverse_util.unfreeze(params)
复制此错误,错误如下:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'
我正在使用:
flax 0.7.4
jax 0.4.13
jaxlib 0.4.13+cuda12.cudnn89
我认为这是与亚麻版本有关的问题,但有谁知道到底发生了什么?任何帮助表示赞赏。如果您需要任何进一步的信息,请告诉我
是对的,问题出在 Flax 版本上。
flax.traverse_util.unfreeze
函数已在 Flax 0.8.0 版本中移除。因此,在 Flax 0.7.4 中看到的错误消息是因为该函数在该版本中不存在。
要解决此问题,可以使用
flax.core.frozen_dict.unfreeze
函数,该函数在 Flax 0.7.4 及更高版本中可用。以下是如何更新代码:
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax.core import frozen_dict
model = nn.Dense(features=3)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
params_unfrozen = frozen_dict.unfreeze(params)
在 S5 代码库的上下文中,需要将
train_helpers.py
文件中的以下行:
params = variables["params"].unfreeze()
替换为:
params = frozen_dict.unfreeze(variables["params"])
这应该可以解决
AttributeError
问题,并允许运行 S5 模型。
另外,还可以考虑将 Flax 升级到最新版本,因为它包含性能改进、错误修复以及新功能。
标签:python,attributeerror,jax,flax From: 78256559