Implementing Logistic Regression with TensorFlow

   The theory behind logistic regression is simple enough that I will not repeat it here. The TensorFlow setup follows the same pattern as before: the Supervisor module (which really is convenient), argparse for command-line flags, and the logging module for log output. The implementation is as follows:

import sys
reload(sys)                       # Python 2 workaround to make UTF-8 the default encoding
sys.setdefaultencoding('utf-8')

import tensorflow as tf
import logging
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np

logging.basicConfig(format="[%(process)d] %(levelname)s %(filename)s:%(lineno)s | %(message)s")
log = logging.getLogger('train')
log.setLevel(logging.INFO)

# Training set: two features per sample plus a binary label.
# The first 8 rows are positive samples (label 1), the rest negative (label 0).
data = np.mat([[0.697, 0.460, 1],
               [0.774, 0.376, 1],
               [0.634, 0.264, 1],
               [0.608, 0.318, 1],
               [0.556, 0.215, 1],
               [0.403, 0.237, 1],
               [0.481, 0.149, 1],
               [0.437, 0.211, 1],
               [0.666, 0.091, 0],
               [0.243, 0.267, 0],
               [0.245, 0.057, 0],
               [0.343, 0.099, 0],
               [0.639, 0.161, 0],
               [0.657, 0.198, 0],
               [0.360, 0.370, 0],
               [0.593, 0.042, 0],
               [0.719, 0.103, 0]])


def log_hook(sess, log_fetches):
    # Runs periodically on a Supervisor background thread (see sv.loop below).
    fetched = sess.run(log_fetches)
    loss = fetched['loss']
    step = fetched['step']
    log.info('Step {} | loss = {:.4f}'.format(step, loss))


def logistic_regression(W, b, x):
    # Sigmoid of the linear model: p = 1 / (1 + exp(-(xW + b))).
    pred = 1 / (1 + tf.exp(-(tf.matmul(x, W) + b)))
    return pred


def main(args):
    W = tf.Variable(tf.random_normal([2, 1], stddev=0.1))
    b = tf.Variable(tf.random_normal([1], stddev=0.1))
    x = tf.to_float(data[:, 0:2])
    y = tf.to_float(data[:, 2])
    global_step = tf.contrib.framework.get_or_create_global_step()
    pred = logistic_regression(W, b, x)
    # Cross-entropy loss, summed over the training set.
    y_ = tf.reshape(y, [-1, 1])
    loss = tf.reduce_sum(-y_ * tf.log(pred) - (1 - y_) * tf.log(1 - pred))
    train_op = tf.train.GradientDescentOptimizer(args.learning_rate).minimize(loss, global_step)

    tf.summary.scalar('loss', loss)
    log_fetches = {
        "W": W,
        "b": b,
        "loss": loss,
        "step": global_step
    }
    sv = tf.train.Supervisor(logdir=args.checkpoint_dir,
                             save_model_secs=args.checkpoint_interval,
                             save_summaries_secs=args.summary_interval)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with sv.managed_session(config=config) as sess:
        # Log the loss every args.log_interval seconds on a background thread.
        sv.loop(args.log_interval, log_hook, [sess, log_fetches])
        while True:
            if sv.should_stop():
                log.info('stopping supervisor')
                break
            try:
                WArr, bArr, _ = sess.run([W, b, train_op])
                # Redraw the samples and the current decision boundary,
                # i.e. the line W[0]*x + W[1]*y + b = 0.
                x0 = np.array(data[:8])   # positive samples
                x0_ = np.array(data[8:])  # negative samples
                plt.scatter(x0[:, 0], x0[:, 1], c='r', label='+')
                plt.scatter(x0_[:, 0], x0_[:, 1], c='b', label='-')
                x1 = np.arange(-0.2, 1.0, 0.1)
                y1 = (-bArr - WArr[0] * x1) / WArr[1]
                plt.plot(x1, y1)
                plt.pause(0.01)
                plt.cla()
            except tf.errors.AbortedError:
                log.error('Aborted')
                break
            except KeyboardInterrupt:
                break
        chkpt_path = os.path.join(args.checkpoint_dir, 'on_stop.ckpt')
        log.info("Training complete, saving chkpt {}".format(chkpt_path))
        sv.saver.save(sess, chkpt_path)
        sv.request_stop()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', default=5e-2, type=float, help='learning rate for the gradient descent update.')
    parser.add_argument('--checkpoint_dir', default='summary/', help='directory to save checkpoints and summaries to.')
    parser.add_argument('--summary_interval', type=int, default=1, help='interval between tensorboard summaries (in s).')
    parser.add_argument('--log_interval', type=int, default=1, help='interval between log messages (in s).')
    parser.add_argument('--checkpoint_interval', type=int, default=20, help='interval between model checkpoints (in s).')

    args = parser.parse_args()
    main(args)

The result is shown in the figure below: a scatter plot of the positive and negative samples, with the decision boundary redrawn after every training step.
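One caveat worth noting (not from the original post): computing tf.log(pred) directly can produce NaNs once the sigmoid saturates to exactly 0 or 1. A numerically safer variant keeps the raw logits and uses TensorFlow's fused op; the sketch below is a drop-in replacement for the pred/loss lines in main(), reusing x, W, b, and y from the listing above.

# Numerically stable loss: let TensorFlow fuse sigmoid + cross-entropy,
# which avoids evaluating log(0) when predictions saturate.
logits = tf.matmul(x, W) + b
loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.reshape(y, [-1, 1]),
                                            logits=logits))
pred = tf.sigmoid(logits)  # still available for plotting / inspection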

   Most of this machinery was already covered in the linear regression post, so here I only add a few notes on how Supervisor is used.
   The typical steps for using a Supervisor are (a minimal sketch follows the list):
(1) Create a Supervisor object and pass it the directory where checkpoints and summaries should be saved.
(2) Request a session from the supervisor with tf.train.Supervisor.managed_session.
(3) Use that session to run the training op, checking at every step whether the supervisor has requested that training stop.
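Put together, the three steps look roughly like this. This is a minimal sketch with a toy one-variable graph; /tmp/sv_demo and the step limit of 100 are arbitrary placeholders, not part of the original post.

import tensorflow as tf

# A toy graph: minimize (w - 3)^2.
w = tf.Variable(0.0)
loss = tf.square(w - 3.0)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

# (1) Create the Supervisor, telling it where to keep checkpoints and summaries.
sv = tf.train.Supervisor(logdir='/tmp/sv_demo', save_model_secs=60)

# (2) Ask the Supervisor for a managed session; variables are initialized,
#     or restored from the latest checkpoint in logdir if one already exists.
with sv.managed_session() as sess:
    # (3) Run the training op, checking should_stop() at every step.
    while not sv.should_stop():
        _, step = sess.run([train_op, global_step])
        if step >= 100:
            sv.request_stop()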
   The graph contains an integer variable named global_step, and the supervisor's services use its value to count how many training steps have been executed. The point of checking sv.should_stop() is this: when an exception is raised in one of the service threads, it is reported to the supervisor, which then sets the stop condition to true; the service threads observe that condition and shut down cleanly.
   In sv.loop(), the first argument is how often (in seconds) the background thread should run, the second is the target function to call on each tick, and the third is the list of arguments to pass to that function, as in the example below.
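For instance, reusing sv, global_step, and train_op from the toy sketch above, a hook that reports the step every 5 seconds would be registered like this (again just a sketch):

def report(sess, fetches):
    # The target callable: invoked on a background LooperThread once per tick.
    print(sess.run(fetches))

with sv.managed_session() as sess:
    # loop(timer_interval_secs, target, args): calls report(sess, fetches)
    # every 5 seconds until the Supervisor requests a stop.
    sv.loop(5, report, [sess, {'step': global_step}])
    while not sv.should_stop():
        _, step = sess.run([train_op, global_step])
        if step >= 100:
            sv.request_stop()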
   The initial model loading, checkpoint saving, and so on are all wired up the moment sv = tf.train.Supervisor(...) is created; see the official documentation for the rest of Supervisor's features.
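In particular, restoring is automatic: pointing a Supervisor at the same logdir and opening a managed session reloads the newest checkpoint, with no explicit saver.restore() call. A hypothetical resume/evaluation snippet, assuming the same graph (the W and b variables from the listing) has already been built in the current process:

# 'summary/' is the script's default --checkpoint_dir.
sv = tf.train.Supervisor(logdir='summary/')
with sv.managed_session() as sess:
    # managed_session restores the newest checkpoint found in logdir.
    print(sess.run([W, b]))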

