# LSTM输出输出维度图示

• LSTM输出输出维度图示
• 单向单层
• 单向多层
• 双向单层
• 双向多层
• 参考

## 单向单层

``````# 构造RNN网络，x的维度5，隐层的维度10,网络的层数2
rnn_seq = nn.RNN(5, 10, 1)
# 构造一个输入序列，句长为 6，batch 是 3， 每个单词使用长度是 5的向量表示
# 输入维度为:[seq_len,batch_size,output_dim]
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out, ht = rnn_seq(x)  # h0可以指定或者不指定
``````
``````#[T, B, H]
out.shape
``````
``````torch.Size([6, 3, 10])
``````
``````# [1, B, T]
ht.shape
``````
``````torch.Size([1, 3, 10])
``````

## 单向多层

``````# 构造RNN网络，x的维度5，隐层的维度10,网络的层数3
rnn_seq = nn.RNN(5, 10, 3)
# 构造一个输入序列，句长为 6，batch 是 3， 每个单词使用长度是 5的向量表示
# 输入维度为:[seq_len,batch_size,output_dim]
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out, ht = rnn_seq(x)  # h0可以指定或者不指定
``````
``````#[T, B, H]
out.shape
``````
``````torch.Size([6, 3, 10])
``````
``````#[num_layer, B, H]
ht.shape
``````
``````torch.Size([3, 3, 10])
``````

## 双向单层

``````# 构造RNN网络，x的维度5，隐层的维度10,网络的层数2
rnn_seq = nn.RNN(5, 10, 1, bidirectional=True)
# 构造一个输入序列，句长为 6，batch 是 3， 每个单词使用长度是 5的向量表示
# 输入维度为:[seq_len,batch_size,output_dim]
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out, ht = rnn_seq(x)  # h0可以指定或者不指定
``````
``````# out维度由正向反向的隐藏层拼接形成
out.shape
``````
``````torch.Size([6, 3, 20])
``````
``````# [num_layers * num_directions, batch, hidden_size]
ht.shape
``````
``````torch.Size([2, 3, 10])
``````
``````# 最后时刻的输出
out[-1]
``````
``````tensor([[ 0.3912, -0.3131, -0.5704,  0.1386, -0.0805,  0.2840, -0.4612,  0.4176,
0.1815,  0.3982,  0.3504,  0.0681, -0.3936,  0.5383,  0.0282,  0.3985,
0.3291,  0.3125,  0.3637,  0.1893],
[-0.5409,  0.7553,  0.2176, -0.3243, -0.1724, -0.0350,  0.2422, -0.2549,
0.4105, -0.3549,  0.2171,  0.5521,  0.0122,  0.3783, -0.2583, -0.0181,
0.1647,  0.6133, -0.0935, -0.2087],
[-0.1052,  0.7468, -0.3063, -0.3701, -0.5259, -0.3952, -0.4957,  0.0016,
0.7090, -0.1685,  0.2603,  0.1816, -0.3178,  0.3992, -0.6003, -0.5304,
``````
``````# 正向的的隐藏层输出
ht[0]
``````
``````tensor([[ 0.3912, -0.3131, -0.5704,  0.1386, -0.0805,  0.2840, -0.4612,  0.4176,
0.1815,  0.3982],
[-0.5409,  0.7553,  0.2176, -0.3243, -0.1724, -0.0350,  0.2422, -0.2549,
0.4105, -0.3549],
[-0.1052,  0.7468, -0.3063, -0.3701, -0.5259, -0.3952, -0.4957,  0.0016,
``````
``````# 第0时刻的输出
out[0]
``````
``````tensor([[ 0.4740,  0.4121, -0.5868, -0.2711, -0.2606, -0.4430, -0.5782,  0.8062,
0.7675, -0.7180,  0.5319,  0.4218, -0.5257,  0.5148, -0.7651, -0.1566,
-0.1108,  0.2430, -0.1809, -0.1221],
[ 0.4909,  0.1688, -0.2177, -0.2767,  0.2483, -0.3785, -0.3281,  0.8529,
0.6099, -0.4130,  0.3471,  0.6021, -0.7445,  0.1823, -0.6768,  0.2450,
0.1149,  0.2162, -0.3557, -0.5719],
[-0.2244,  0.7206, -0.0976, -0.5866,  0.3540,  0.1325, -0.5411, -0.7152,
0.3517,  0.3375, -0.8289, -0.7162, -0.1566, -0.3909, -0.4418, -0.2623,
``````
``````# 反向的隐藏层输出
ht[-1]
``````
``````tensor([[ 0.5319,  0.4218, -0.5257,  0.5148, -0.7651, -0.1566, -0.1108,  0.2430,
-0.1809, -0.1221],
[ 0.3471,  0.6021, -0.7445,  0.1823, -0.6768,  0.2450,  0.1149,  0.2162,
-0.3557, -0.5719],
[-0.8289, -0.7162, -0.1566, -0.3909, -0.4418, -0.2623,  0.1497, -0.6729,
``````

## 双向多层

``````# 构造RNN网络，x的维度5，隐层的维度10,网络的层数2
rnn_seq = nn.RNN(5, 10, 2, bidirectional=True)
# 构造一个输入序列，句长为 6，batch 是 3， 每个单词使用长度是 5的向量表示
# 输入维度为:[seq_len,batch_size,output_dim]
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out, ht = rnn_seq(x)  # h0可以指定或者不指定
``````
``````# 前向后向拼接
out.shape
``````
``````torch.Size([6, 3, 20])
``````
``````# [num_layers * num_directions, batch, hidden_size]
ht.shape
``````
``````torch.Size([4, 3, 10])
``````
``````out[0]
``````
``````tensor([[ 0.1704,  0.4581,  0.1723,  0.2236,  0.0432,  0.6190,  0.1974,  0.2000,
-0.5012, -0.1075, -0.1713, -0.4623, -0.3120,  0.0759,  0.4959, -0.8103,
-0.2548,  0.4587, -0.2821,  0.7620],
[ 0.4484,  0.4716,  0.0320,  0.1836,  0.1117,  0.5677,  0.0560,  0.2435,
-0.3318, -0.0584, -0.4861,  0.0913, -0.2517,  0.1683,  0.5459, -0.3377,
-0.6199, -0.4051, -0.2039,  0.4189],
[ 0.5295,  0.4061, -0.1754,  0.2779, -0.0318,  0.6160,  0.1777,  0.5757,
-0.1380, -0.2663, -0.6953, -0.0388, -0.2153,  0.5317,  0.2948, -0.1002,
``````
``````ht[2:]
``````
``````tensor([[[-0.0053,  0.7034,  0.1532, -0.0558,  0.5286,  0.8320, -0.2079,
-0.5418, -0.4331,  0.1198],
[ 0.2822,  0.6841, -0.5430,  0.1567,  0.5371,  0.8532,  0.0513,
-0.5214, -0.1258, -0.0206],
[ 0.3767,  0.6684,  0.0900,  0.2732,  0.4522,  0.8421, -0.0925,
-0.1310, -0.2546, -0.0969]],

[[-0.1713, -0.4623, -0.3120,  0.0759,  0.4959, -0.8103, -0.2548,
0.4587, -0.2821,  0.7620],
[-0.4861,  0.0913, -0.2517,  0.1683,  0.5459, -0.3377, -0.6199,
-0.4051, -0.2039,  0.4189],
[-0.6953, -0.0388, -0.2153,  0.5317,  0.2948, -0.1002, -0.6486,