Without further ado: turning the image directory into a txt list is very simple. The code is attached directly below:
import os
import random

def generate(dir):
    files = os.listdir(dir)
    print('****************')
    print('input :', dir)
    print('start...')
    listText = open('dataset.txt', 'w')
    random.shuffle(files)  # shuffle the file list in place
    for file in files:
        fileType = os.path.splitext(file)  # (root, extension)
        if fileType[1] == '.txt':  # skip any existing txt files
            continue
        name = file + '\n'
        listText.write(name)
    listText.close()
    print('done!')
    print('****************')

if __name__ == '__main__':
    generate('test_data/canon/')
The code above uses random.shuffle to randomize the file order: it simply writes the image file names from the directory, shuffled, into dataset.txt.
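Note that the extension check only skips .txt files; if the directory might contain other non-image files, filtering by an extension whitelist is safer. A minimal variant sketch (the IMG_EXTS set is my own assumption, not part of the original script):

import os
import random

IMG_EXTS = {'.jpg', '.jpeg', '.png'}  # assumed whitelist; adjust to your data

def generate(dir):
    # keep only files whose extension is in the whitelist
    files = [f for f in os.listdir(dir)
             if os.path.splitext(f)[1].lower() in IMG_EXTS]
    random.shuffle(files)
    with open('dataset.txt', 'w') as listText:  # the with-block closes the file even on error
        for f in files:
            listText.write(f + '\n')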
Reading the image data back from the txt file and training epoch by epoch looks like this:
import argparse
import os

import tensorflow as tf

class DataPipeLine(object):
    def __init__(self, path):
        self.path = path

    def produce_one_samples(self):
        dirname = os.path.dirname(self.path)
        with open(self.path, 'r') as fid:
            flist = [l.strip() for l in fid]  # xreadlines() is Python 2 only; iterating the file works everywhere
        input_files = [os.path.join(dirname, 'iphone', f) for f in flist]
        output_files = [os.path.join(dirname, 'canon', f) for f in flist]
        # slice_input_producer dequeues one (input, output) pair at a time, in random order
        input_queue, output_queue = tf.train.slice_input_producer(
            [input_files, output_files], shuffle=True, seed=1234, num_epochs=None)
        input_file = tf.read_file(input_queue)
        output_file = tf.read_file(output_queue)
        im_input = tf.image.decode_jpeg(input_file, channels=3)
        im_output = tf.image.decode_jpeg(output_file, channels=3)
        sample = {}
        with tf.name_scope('normalize_images'):
            im_input = tf.to_float(im_input) / 255.0
            im_output = tf.to_float(im_output) / 255.0
        # concatenate along channels so both images go through the same resize
        inout = tf.concat([im_input, im_output], axis=2)
        inout.set_shape([None, None, 6])
        inout = tf.image.resize_images(inout, [100, 100])
        sample['input'] = inout[:, :, :3]
        sample['output'] = inout[:, :, 3:]
        return sample

def main(args):
    sample = DataPipeLine(args.data_dir).produce_one_samples()
    samples = tf.train.batch(sample, batch_size=args.batch_size,
                             num_threads=2,
                             capacity=32)
    # per-batch squared-error loss between the paired images
    loss = tf.reduce_sum(tf.pow(samples['input'] - samples['output'], 2)) / (2 * args.batch_size)
    global_step = tf.contrib.framework.get_or_create_global_step()
    total_batch = int(400 / args.batch_size)  # 400 = number of images in the list
    sv = tf.train.Supervisor()
    total_loss = 0
    with sv.managed_session() as sess:
        step = 0
        while True:
            if sv.should_stop():
                print("stopping supervisor")
                break
            try:
                loss_ = sess.run(loss)
                total_loss += loss_
                step += 1
                print("step:%d,loss:%.2f" % (step, loss_))
                if step % total_batch == 0:
                    print("%d epochs,total loss:%.2f" % (step / total_batch, total_loss))
                    total_loss = 0
            except tf.errors.AbortedError:
                print("Aborted")
                break
            except KeyboardInterrupt:
                break
        sv.request_stop()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default='test_data/dataset.txt',
                        help="The path of the input txt")
    parser.add_argument("--batch_size", type=int, default=40,
                        help="Number of images in each batch")
    parser.add_argument("--epochs", type=int, default=30,
                        help="The number of epochs")
    args = parser.parse_args()
    main(args)
Here the image names are read from the txt file and joined with the directory path to obtain each image's full path. Inside DataPipeLine, tf.train.slice_input_producer fetches the images; with shuffle set to True, it draws one image from the list at random each time as a sample. To assemble batches, main uses tf.train.batch rather than tf.train.shuffle_batch, so that every sample is used exactly once per epoch. What sess.run computes is the squared-error loss between the two images; at the end of each epoch the total loss is printed, and since it comes out identical every time, each epoch really does process the same set of data. One epoch is the total number of images divided by batch_size: here 400 images with batch_size 40 gives 10 steps per epoch.
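The queue-based input pipeline above was deprecated in later TensorFlow releases in favor of tf.data. A minimal sketch of the same logic rewritten with tf.data (assuming TensorFlow 1.4+; the paths, 100x100 size, and batch size of 40 follow the original code):

import os
import tensorflow as tf

def make_dataset(list_path, batch_size):
    dirname = os.path.dirname(list_path)
    with open(list_path, 'r') as fid:
        names = [l.strip() for l in fid]
    inputs = [os.path.join(dirname, 'iphone', n) for n in names]
    outputs = [os.path.join(dirname, 'canon', n) for n in names]

    def _parse(in_path, out_path):
        im_in = tf.image.decode_jpeg(tf.read_file(in_path), channels=3)
        im_out = tf.image.decode_jpeg(tf.read_file(out_path), channels=3)
        # convert_image_dtype scales uint8 images to float32 in [0, 1]
        inout = tf.concat([tf.image.convert_image_dtype(im_in, tf.float32),
                           tf.image.convert_image_dtype(im_out, tf.float32)], axis=2)
        inout = tf.image.resize_images(inout, [100, 100])
        return {'input': inout[:, :, :3], 'output': inout[:, :, 3:]}

    ds = tf.data.Dataset.from_tensor_slices((inputs, outputs))
    ds = ds.shuffle(len(names), seed=1234)  # buffer covers the whole list: full shuffle each epoch
    ds = ds.map(_parse, num_parallel_calls=2)
    ds = ds.batch(batch_size)  # no shuffle_batch needed: each sample appears once per pass
    return ds

samples = make_dataset('test_data/dataset.txt', 40).make_one_shot_iterator().get_next()
loss = tf.reduce_sum(tf.pow(samples['input'] - samples['output'], 2)) / (2 * 40)
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(loss))
    except tf.errors.OutOfRangeError:  # raised after the single pass over the data ends
        pass

Because the dataset is not repeated with .repeat(), the one-shot iterator yields exactly one epoch and then raises OutOfRangeError, which replaces the manual step/total_batch bookkeeping from the Supervisor loop.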