Tensorflow搭建RNN和LSTM网络

写在前面

在尝试搭建了卷积神经网络后,的确可以通过一个简单的卷积神经网络实现某些应用。虽然卷积神经网络拥有强大的功能,但是并不能胜任所有的场景。在本篇博客中,我们将继续学习其它结构的神经网络。

另外,本文首先要阐述神经网络的保存和调用过程。这对于模型训练后进行其它场景下的应用是十分必要的。

1、训练模型的保存和读取过程

训练的模型保存对于其它程序的调用是必要的一步。主要使用tf.train.Saver()函数进行模型的保存。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#saver the variables
import tensorflow as tf
import numpy as np

# #save to file
# #remember to define the same dtype and shape when restore
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess, 'my_net/save_net.ckpt')#指定了保存模型的地址以及文件名称,filename.ckpt
print("Save to path: ", save_path)

保存结果如下:

保存模型生成的文件

各文件介绍如下:

checkpoint:检查点文件。

data:数据文件,保存了网络的数据,例如变量、权值、偏移量等。

index:是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。

meta:保存了tensorflow计算图的结构信息。

读取或加载训练模型使用saver.restore(),需要先定义好saver。

1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf
import numpy as np

W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name='weights')
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name='biases')

saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'my_net/save_net.ckpt')#模型保存的位置
print("weights:", sess.run(W))
print("biases:", sess.run(b))

2、循环神经网络RNN

既然有了强大的卷积神经网络,为何还需要循环神经网络?很显然,前面遇到的例子中,我们似乎没有考虑数据和数据之间是否存在某些联系;然而,在许多实际的场景中,数据之间往往存在某些联系。例如正在写的博客,单抽出一个字不能表述完整的内容,但是将许多字连成一个句子就可以表达一个指定的内容。我们可以看到,一个句子可以看作一个序列,这个序列是由一些词和短语构成的,前一个输入和后一个输入之间是有联系的。

当我们读到“史蒂夫·乔布斯”这个名字时,会自然而然的联想到当今最优秀的科技公司的创始人;当逆序读出来时,整个过程会变得极其拗口,自然我们也难以对其进行一些联想和分析。

来看一下普通的神经网络对序列数据的处理。输入一个数据到神经网络,神经网络输出一个结果,直到所有的数据全部输入完成。很显然,这几个数据对应的输出结果之间没有任何的联系,即使这几个数据之间存在某种联系。然而,RNN则将数据之间的关联加以分析。

假设输入的数据时data0、data1、data2、data3,神经网络首先对data0输出一个result0,并且保存一个状态h(0),之后输入data1,此时也会产生一个状态h(1),此时输出的result(0),这时的result(0)是由h(0)和h(1)共同决定的,以此类推,后续的序列分析都会将前面保存的记忆状态进行调用分析输出结果。

循环神经网络RNN

3、长短时记忆网络LSTM

RNN看似很好,但是其存在一个极大的缺点,容易出现梯度消失和梯度爆炸。有关RNN出现梯度消失和梯度爆炸的原因,可以参考莫烦Python中对这个概念的解释。

为了在一定程度上解决RNN带来的这种问题,这时引入了长短时记忆网络(LSTM)。

原始的RNN的隐藏层只有一个状态h,这个状态对短期的输入是非常敏感的。但是,由于其存在梯度消失和梯度下降,导致其对长期的输入敏感度降低,这对距离较远的输入信息是极为不利的。因此LSTM在隐藏层中加入了一个状态c,让这个状态来保存长期的状态。