torch.nn.Flatten 和 torch.flatten 的区别
大约 1 分钟
torch.nn.Flatten()是一个用于对tensor降维的对象, 而torch.flatten()是一个用于对tensor降维的函数.
torch.nn.Flatten
torch.nn.Flatten(start_dim=1, end_dim=-1)
有两个参数表示降维维度的范围, 默认是从第二个维度(index 1)到最后一个维度(index -1)
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
上述定义了一个四维的tensor, 使用默认的torch.nn.Flatten()
实例化对象是取index 1
到index -1
对应上述的input
就是1, 5, 5
, 降维大小计算是降维前的大小累乘, 因此output
的size是[32, 1*5*5]
, 后者不用默认参数, 使用区间[0, 2]
, 故降维后的size是[32*1*5, 5]
.
torch.flatten
torch.flatten(input: Tensor, start_dim: _int = 0, end_dim: _int = -1)
, 是一个有三个参数的函数, 第一个input
是输入的Tensor对象, 第二和第三个是维度索引的起始和结束.
>>> input = torch.randn(32, 1, 5, 5)
>>> output = torch.flatten(input)
>>> output.size()
torch.Size([800])
上述示例中, 默认的torch.flatten()
使用区间[0, -1]
, 即降至一维, 所以输出size是[32*1*5*5]
.
补充
对于一个Tensor
对象而言, .size()
方法和.shape
属性都可获得该对象的Size
.
>>> input = torch.randn(32, 1, 5, 5)
>>> input.size()
torch.Size([32, 1, 5, 5])
>>> input.shape
torch.Size([32, 1, 5, 5])