参考转自
https://blog.csdn.net/u010960155/article/details/81707498
实验如下
import tensorflow as tf
import numpy as np
def dynamic_rnn(rnn_type='lstm'):
# 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time),8代表每个序列的维度
X = np.random.randn(3, 6, 4)
# 第二个输入的实际长度为4。在此处也就是time_step 设定为4了,不再是6.注意看返回结果state,不再是步长内的第6个而是第4个。
X[1, 4:] = 0
#记录三个输入的实际步长
X_lengths = [6, 4, 6]
rnn_hidden_size = 5
if rnn_type == 'lstm':
cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
else:
cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
o1, s1 = session.run([outputs, last_states])
print(np.shape(o1))
print(o1)
print(np.shape(s1))
print(s1)
if __name__ == '__main__':
dynamic_rnn(rnn_type='lstm')
lstm模式下
outputs 为 [batch_size , time_step, rnn_unit]
last_states (c,h) 结构为[ 2,batch_size ,rnn_unit]
last_states.h [batch_size ,rnn_unit]
last_states.c [batch_size ,rnn_unit]
last_states.h 和outputs 最后一个相同
outputs:包含了所有时刻的输出 H ,
states :包含了 "每个time_step内"最后一个时刻的输出 H 和 C
阅读(1894) | 评论(0) | 转发(0) |