torch.nn.Flatten 和 torch.flatten 的区别
大约 1 分钟
torch.nn.Flatten()是一个用于对tensor降维的对象, 而torch.flatten()是一个用于对tensor降维的函数.
torch.nn.Flatten
torch.nn.Flatten(start_dim=1, end_dim=-1) 有两个参数表示降维维度的范围, 默认是从第二个维度(index 1)到最后一个维度(index -1)
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
上述定义了一个四维的tensor, 使用默认的torch.nn.Flatten()实例化对象是取index 1到index -1对应上述的input就是1, 5, 5, 降维大小计算是降维前的大小累乘, 因此output的size是[32, 1*5*5], 后者不用默认参数, 使用区间[0, 2], 故降维后的size是[32*1*5, 5].
torch.flatten
torch.flatten(input: Tensor, start_dim: _int = 0, end_dim: _int = -1), 是一个有三个参数的函数, 第一个input是输入的Tensor对象, 第二和第三个是维度索引的起始和结束.
>>> input = torch.randn(32, 1, 5, 5)
>>> output = torch.flatten(input)
>>> output.size()
torch.Size([800])
上述示例中, 默认的torch.flatten()使用区间[0, -1], 即降至一维, 所以输出size是[32*1*5*5].
补充
对于一个Tensor对象而言, .size()方法和.shape属性都可获得该对象的Size.
>>> input = torch.randn(32, 1, 5, 5)
>>> input.size()
torch.Size([32, 1, 5, 5])
>>> input.shape
torch.Size([32, 1, 5, 5])
