云端的数据库存储在google的服务器,所以无法通过tfds.load('mnist', split='train')这样的方式直接加载,下面是一些变通的方式来读取。
一、MNIST数据库
1.我用tfds.load去加载本地的mnist数据库,mnist/3.0.1应该具有什么样的数据格式
答:
如果你正在使用`tfds.load`来加载本地的MNIST数据库,那么你需要确保数据是按照TensorFlow的数据格式进行组织的。在TensorFlow Dataset (TFDS)透明地下载和序列化数据集后,它会将数据保存为`tfrecord`文件,配以描述数据集特征的`dataset_info.json`文件,这些文件共同组成了一个数据集的版本(即 'mnist/3.0.1')。
下面是MNIST数据集的标准TFDS目录结构:
```
[mnist]
│
└───3.0.1
│
├───以tfrecord格式存储的实际数据 ('mnist-train.tfrecord-00000-of-00001', 'mnist-test.tfrecord-00000-of-00001')
│
└─── 'dataset_info.json' (数据集的元信息)
```
在此结构中,`mnist-train.tfrecord-00000-of-00001`和`mnist-test.tfrecord-00000-of-00001`是经TFDS处理值的MNIST训练和测试数据,以tfrecord的形式存储。`dataset_info.json`文件包含有关数据集的信息,如版本,特征描述,数据集大小等。
请注意,通常情况下,我们并不直接创建这样的文件和目录结构,而是使用`tfds.builder`或`tfds.load`来自动下载和处理原始数据。
如果你有本地MNIST的原始数据,你可能需要编写自定义的`tfds.core.DatasetBuilder`来处理和加载这些数据,或者将你的本地数据转换成tfrecord格式并按照上述目录结构进行存储。
2.我在本地拥有train-images-idx3-ubyte.gz这样类型的数据,我如何将我的数据按照上述的目录结构进行转换
答:对于 `train-images-idx3-ubyte.gz` 这种格式的数据,它是原始的 MNIST 数据库的压缩格式,你需要先解压这类文件,然后将解压出的文件转换成 `tfrecord` 的格式。以下是一种可能的步骤:
1). **解压文件**:
你可以通过 gzip 命令(在 Linux 操作系统下)或者相关的压缩工具(在 Windows 操作系统下)来进行解压。
对于 Linux 用户来说,可以使用以下命令来解压文件:
```
gunzip train-images-idx3-ubyte.gz
```
2). **转换为 `tfrecord` 格式**:
解压之后,你需要编写一个转换脚本,将解压出的 `.ubyte` 文件转换成 `tfrecord` 文件,这一步需要使用到 TensorFlow 的 API。
以下是一个 Python 代码示例,用于将 MNIST 数据库转换为 `tfrecord` 文件:
import numpy as np import tensorflow as tf import os os.environ['CUDA_VISIBLE_DEVICES']='0' tf.config.experimental.set_visible_devices([], 'GPU') # 在此处填入你的解压后的文件 imgs_file ='' labels_file = '' def _read32(bytestream): dt = np.dtype(np.uint32).newbyteorder('>') return np.frombuffer(bytestream.read(4), dtype=dt)[0] def get_data(filename): with open(filename, 'rb') as f: magic = _read32(f) if magic != 2051: raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name)) num_images = _read32(f) rows = _read32(f) cols = _read32(f) buf = f.read(rows * cols * num_images) data = np.frombuffer(buf, dtype=np.uint8) data = data.reshape(num_images, rows, cols, 1) return data def get_labels(filename): with open(filename, 'rb') as f: magic = _read32(f) if magic != 2049: raise ValueError( 'Invalid magic number %d in MNIST label file: %s' % (magic, f.name)) num_items = _read32(f) buf = f.read(num_items) labels = np.frombuffer(buf, dtype=np.uint8) return labels def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): if isinstance(value, type(tf.constant(0))): value = value.numpy() return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 把图片数据和标签数据转换为 tfrecord def convert_to_tfrecord(images, labels, output_file,datasetSize): """Converts a file to TFRecords.""" print('Generating %s' % output_file) with tf.io.TFRecordWriter(output_file) as record_writer: for i in range(datasetSize): image_raw = images[i].tobytes() # 序列化图片数据 example = tf.train.Example(features=tf.train.Features(feature={ 'height': _int64_feature(28), 'width': _int64_feature(28), 'depth': _int64_feature(1), 'label': _int64_feature(int(labels[i])), 'image_raw': _bytes_feature(image_raw)})) record_writer.write(example.SerializeToString()) images = get_data(imgs_file) labels = get_labels(labels_file) convert_to_tfrecord(images, labels, 'output.tfrecord',labels.shape[0])
以上是一个基本的转换脚本,实际投入使用时可能需要进行一些修改来适应你的需求。
此外,你可能要参考 TensorFlow 的官方指南和 API 参考文档来了解如何使用 `tfrecord`。在将数据预处理为 `tfrecord` 后,你可以建立一个新目录如 `TFDS_dir/mnist/3.0.1/` ,然后将 `tfrecord` 文件放入此目录,同时考虑创建包含必要数据信息的 `dataset_info.json` 文件。
3.回到前面的问题,我已经创建了mnist数据库中的tfrecord格式的文件,接下来我该如何创建'dataset_info.json' 文件?
答:
在创建针对MNIST数据库的'dataset_info.json'文件时,您需要知道数据集有关的特定信息。例如,对于MNIST数据库,总的示例数量是70000,其中60000用于训练,10000用于测试。并且它有10个类别,分别对应数字0到9。
以下是您可以用来创建'dataset_info.json'的Python代码示例:
import json # 创建描述MNIST数据库的字典 dataset_info = { "name": "MNIST", "version": "3.1.0", "size": "11MB", # You can specify the size if you know it "num_samples": { "train": 60000, "test": 10000 }, "num_classes": 10, "labels": [str(i) for i in range(10)] # labels are string of number 0 to 9 } # 将这个字典写入到json文件中 with open('dataset_info.json', 'w') as f: json.dump(dataset_info, f)
这样,您就创建了一个名为 'dataset_info.json' 的文件,其中包含了MNIST数据集的相关信息。这在后续训练模型和评估结果时十分有用,因为您可以直接读取这个文件,而不是每次都需要手动查找这些信息。
标签:tfrecord,tfds,数据库,labels,dataset,json,train,MNIST,加载 From: https://www.cnblogs.com/qq552048250/p/17747568.html