知道如何使用 pydantic 输入提示 tf 张量吗?尝试了默认的 tf.Tensor
RuntimeError: no validator found for <class 'tensorflow.python.framework.ops.Tensor'>, see `arbitrary_types_allowed` in Config
和 tf.flaot32
RuntimeError: error checking inheritance of tf.float32 (type: DType)
查看 pydantic 中的文档,我相信需要定义类似这个任意类的东西...
class Tensor:
def __init__(self, Tensor):
self.Tensor = Union[
tensorflow.python.framework.ops.Tensor,
tensorflow.python.framework.sparse_tensor.SparseTensor,
tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor,
tensorflow.python.framework.ops.EagerTensor,
]
在 main 中包含以下内容..
class Main(BaseModel):
tensor : Tensor
class Config:
arbitary_types_allowed = True
你可以使用
annotation: Type[T]
来提示任意类型,包括
tensorflow.Tensor
,而不必创建自定义类来进行验证。以下是操作方法:
from typing import Type
import tensorflow as tf
from pydantic import BaseModel
class Model(BaseModel):
tensor: Type[tf.Tensor]
class Config:
arbitrary_types_allowed = True
tensor = tf.constant([1, 2, 3])
model = Model(tensor=tensor)
print(model)
print(type(model.tensor))
这将输出:
tensor=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], ...)>
<class 'tensorflow.python.framework.ops.EagerTensor'>
解释:
-
arbitrary_types_allowed = True
: 此配置至关重要。它告诉 Pydantic 不要尝试验证tf.Tensor
类型,因为 Pydantic 默认情况下不知道如何处理它。 -
Type[tf.Tensor]
: 我们使用Type[tf.Tensor]
来提示类型。这意味着tensor
属性应该接收tf.Tensor
类本身,而不是该类的实例。
重要注意事项:
-
验证的缺乏
: 虽然此方法允许你使用
tf.Tensor
,但 Pydantic 不会执行任何类型验证。它只会检查是否传递了tf.Tensor
类型的对象。 -
其他 TensorFlow 类型
: 你可以使用相同的方法来提示其他 TensorFlow 类型,例如
tf.Variable
、tf.Operation
等。只需将tf.Tensor
替换为你需要的特定类型。
如果你需要更严格的类型检查或自定义验证,你可以考虑创建自定义的 Pydantic 类型或使用 Pydantic 的验证器来实现。
标签:python,tensorflow,python-typing,pydantic From: 72184958