Long time no see!
I've found that even though I'm a senior now, I still have quite a lot of classes and not much time each day to tinker, so it's been a long while since the last update. Besides, over the past few days I've been picking up bits of Linux: shell, vim and tmux. I'm far from proficient and not sure where to start, but at least I can now read other people's shell scripts, which I'll count as progress (smile).
The problem I ran into recently is with tf.estimator. In the tfnet code,
tfnet_est.predict
raises an error if the input dataset contains only low_wav; it insists on (high_wav, low_wav) pairs. That is strange, because logically predict should only need low_wav.
TFNetEstimator
is a class that inherits from tf.estimator.Estimator with a modified model_fn. We now want to call TFNetEstimator.predict.
With
input_fn=lambda: get_dummy_dataset().make_one_shot_iterator().get_next()
there is no problem.
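For reference, the working call assembled as a whole looks roughly like this (just a sketch pieced together from the snippets in this post; TFNetEstimator, nets.default_net, get_dummy_dataset and DUMMY_MODEL_PATH all come from the repo and are assumed to be in scope):
config = tf.estimator.RunConfig(log_step_count_steps=1)
tfnet_est = TFNetEstimator(**nets.default_net(), config=config,
                           model_dir=DUMMY_MODEL_PATH)
preds = tfnet_est.predict(
    input_fn=lambda: get_dummy_dataset().make_one_shot_iterator().get_next())
print(next(preds))  # predict() returns a generator; pull one result to actually run the graph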
But with
dset = ds.single_file_dataset(LQ_AUDIO_FILE)
# RunConfig for more printing since we are only training for very few steps
config = tf.estimator.RunConfig(log_step_count_steps=1)
tfnet_est = TFNetEstimator(**nets.default_net(), config=config,
                           model_dir=DUMMY_MODEL_PATH)
preds = tfnet_est.predict(input_fn=lambda: dset.make_one_shot_iterator().get_next())
it errors out. Let's look at these two datasets in detail:
1) ds.single_file_dataset:
def single_file_dataset(filename, upsample_rate=2, seg_length=8192, batchsize=16, **kwargs):
    """Loads a single audio file and processes it in order; use for prediction"""
    DEBUG("Ignored args: " + str(kwargs))
    audio_in = filters.upsample(_load_wav(filename), upsample_rate)
    audio_len, channels = audio_in.shape
    padlen = seg_length - audio_len % seg_length
    audio_padded = np.pad(audio_in, [(0, padlen), (0, 0)], 'constant')
    audio_segs = audio_padded.reshape((-1, seg_length, channels))

    def _gen():
        for seg in audio_segs:
            yield seg

    dset = tf.data.Dataset.from_generator(_gen,
                                          output_types=tf.float32,
                                          output_shapes=[seg_length, channels])
    dset = dset.batch(batchsize)
    return dset
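A quick way to see what this dataset yields (TF 1.x style, matching the code above; LQ_AUDIO_FILE is the same path as before, and the channel count assumes a mono file):
dset = ds.single_file_dataset(LQ_AUDIO_FILE)
print(dset.output_types)   # tf.float32: one bare tensor per element, no (features, labels) pair
print(dset.output_shapes)  # (?, 8192, 1) for a mono file: batch dim unknown, batch() keeps the remainder
So this dataset produces only batches of low_wav segments, which is exactly what one would expect to be enough for predict.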
2) get_dummy_dataset:
def get_dummy_dataset(length=8192, channels=1, count=16,
                      batchsize=4, repeat=200,
                      drop_remainder=True):
    """Dummy dataset generator for use in unit tests"""
    dummy_hr = np.array(np.linspace(0, 1, length)[:, np.newaxis], dtype=np.float32)
    dummy_hr = np.hstack([dummy_hr for _ in range(channels)])
    dummy_lr = dummy_hr.copy()
    dummy_lr[1::2] = 0
    dummy_train = [(dummy_lr.copy(), dummy_hr.copy()) for _ in range(count)]
    dummy_dset = tf.data.Dataset.from_generator(lambda: ((l, h) for l, h in dummy_train),
                                                output_types=(tf.float32, tf.float32),
                                                output_shapes=([length, channels],
                                                               [length, channels]))
    # 16 samples per epoch, 2 epochs, batch size 4 -> 8 iterations
    dummy_dset = dummy_dset.repeat(repeat).batch(batchsize, drop_remainder=drop_remainder)
    return dummy_dset
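Running the same check on the dummy dataset shows a different element structure (again TF 1.x style; the shapes below assume the default arguments):
dummy = get_dummy_dataset()
print(dummy.output_types)   # (tf.float32, tf.float32): a (low, high) pair per element
print(dummy.output_shapes)  # both parts are (4, 8192, 1): fixed batch dim thanks to drop_remainder=True
So get_dummy_dataset yields (low_wav, high_wav) tuples, which the Estimator input_fn convention reads as (features, labels), while single_file_dataset yields a bare tensor. Keep that structural difference in mind.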
https://tensorexamples.com/2020/07/27/Using-the-tf.data.Dataset.html
Let's first take a look at Dataset.