tf.data.Dateset中batch, repeat, shuffle顺序问题

引言

本文旨在介绍tf.data.Dataset中batch, repeat, shuffle以及三者的顺序问题。首先介绍了这三个函数单独作用的结果,而后给出了相互作用下的影响。

一、单独作用

shuffle()

 shuffle(buffsize) 用于将数据打乱,其中buffsize的大小越大,数据的混乱程度越高,因为shuffle的实现思路为:** 开辟一可容纳buffsize个数据的缓冲区,初始时将数据的前buffsize个读入缓冲区,而后随机在缓冲区里选择一个输出,同时将数据的第buffsize+1个读入缓冲区**。由此不难理解,如果buffsize很小,比如为1时就根本没有打乱。给出示例代码和结果如下:

data=tf.range(0,10)
data=tf.data.Dataset.from_tensor_slices(data)
data1=data.shuffle(5)
for i in data1:
    print(i.numpy())
'''
#结果
4
0
5
2
3
1
6
8
7
9
'''

可以测试,如果多次执行代码中的输出程序,每次的打乱结果都会发生变化,但是第一个输出的值永远都是在0~4的范围之内,这是因为我们设置的buffsize=5。

repeat()

repeat(count) 用于将数据重复count次,相当于我们训练时的epoch,示例代码及结果如下:

data1=data.repeat(2)
for i in data:
    print(i.numpy())
'''
#结果
0 1 2 3 4 5 6 7 8 9 
0 1 2 3 4 5 6 7 8 9 
'''

batch()

batch(batch_size)用于将数据划分为多个batch,同时tensorflow中有着很好的调整功能,当最后一个batch不满足batchsize时就以当前长度输出。

data=tf.range(0,10.)[:,None]
data=tf.data.Dataset.from_tensor_slices(data)
data1=data.batch(4)
for i in data1:
    print(i.numpy())
'''
#结果
[[0.]
 [1.]
 [2.]
 [3.]]
[[4.]
 [5.]
 [6.]
 [7.]]
[[8.]
 [9.]]

可以看到最后一个batch只含有2个数据。

相互作用

1. 先repeat再shuffle

 其结果就是对repeat后的数据进行打乱,这会使得不同epcoh间的数据被打乱,即前一个epcoh中数据未加载完,下一个epoch中数据可能插入,导致一个epoch中可能数据重复多次。

temp1=data.repeat(2).shuffle(5)
for i in temp1:
    print(i.numpy())
# 结果
# [4.]	[0.]	[2.]	[7.]	[8.]	[5.]	[3.]	[6.]	[1.]
#[1.]	[0.]	[2.]	[5.]	[4.]	[9.]	[9.]	[3.]	[6.]	[7.]	[8.]

可以看出,在第一个epoch还未结束(未出现9时)就已经出现了下个epoch的0

2.先shuffle再repeat

  先shuffle再repeat,epoch内部打乱,一定先输出完一个epoch内所有值。

temp2=data.shuffle(5).repeat(2)
for i in temp2:
   print(i.numpy())
#结果
# [2.]	[4.]	[5.]	[6.]	[0.]	[1.]	[3.]	[9.]	[8.]	[7.]	
#[3.]	[1.]	[4.]	[7.]	[0.]	[6.]	[9.]	[2.]	[8.]	[5.]

3. 先batch再repeat

 先batch再repeat,为对于batch的复制。

temp3=data.batch(4).repeat(2)
for i in temp3:
    print(i.numpy())
#结果
# [[0.]	 [1.]	 [2.]	 [3.]]	[[4.]	 [5.]	 [6.]	 [7.]]	[[8.]	 [9.]]	
#[[0.]	 [1.]	 [2.]	 [3.]]	[[4.]	 [5.]	 [6.]	 [7.]]	[[8.]	 [9.]

4. 先repeat再batch

  先repeat再batch是对于重复后数组的分组。

temp4=data.repeat(2).batch(4)
for i in temp4:
    print(i.numpy())
# [[0.]	 [1.]	 [2.]	 [3.]]	[[4.]	 [5.]	 [6.]	 [7.]]	[[8.]	 
#[9.]	 [0.]	 [1.]]	[[2.]	 [3.]	 [4.]	 [5.]]	[[6.]	 [7.]	 
#[8.]	 [9.]]

5.先batch再shuffle

 先batch再shuffle是对不同组间的打乱

temp5=data.batch(4).shuffle(5)
for i in temp5:
    print(i.numpy())
# [[4.]	 [5.]	 [6.]	 [7.]]	
#[[0.]	 [1.]	 [2.]	 [3.]]	
#[[8.]	 [9.]]

6.先shuffle再batch

 先shuffle再batch是在对打乱后的数据分组

temp6=data.shuffle(5).batch(4)
for i in temp6:
    print(i.numpy())
# [[3.]	 [2.]	 [1.]	 [7.]]	
#[[6.]	 [5.]	 [4.]	 [0.]]	
[[9.]	 [8.]]

参考文献

【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法
tf.data.Dataset关于batch,repeat,shuffle的讲解
#深入探究# Tensorflow.Data.shuffle 方法的实现原理和 buffer_size 参数的作用
加载自定义图片数据集到Dataset

你可能感兴趣的