torch中 squeeze 和 unsqueeze 的不同用法

2023-12-14 17:00:23

1 squeeze()的用法

torch.squeeze
torch.squeeze(input, dim = None, *, out = None)
. input:输入的张量
. dim 选择需要降维的维度,默认是None

为什么要降维
如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。在多维张量中,如果某一个维度是1,那么这个维度是为了扩充维度,所以为了加快计算,进行降维操作时可以去掉1的维度。

1)对指定的维度进行降维

import torch
A = torch.ones((1,2,3,1,4,2))
A.shape
torch.Size([1, 2, 3, 1, 4, 2])

(a)如果某个维度为1,则对此维度进行降维

B = torch.squeeze(A,dim=0)
B.shape
torch.Size([2, 3, 1, 4, 2])

C = torch.squeeze(A,dim=3)
C.shape
torch.Size([1, 2, 3, 4, 2])

(b) 某个维度不为1,则无法对此维度进行降维

D = torch.squeeze(A,dim=2)
D.shape
torch.Size([1, 2, 3, 1, 4, 2])

2) 默认使用torch.squeeze ,不指定维度

E = torch.squeeze(A)
E.shape
torch.Size([2, 3, 4, 2])

不指定维度,则会将所有为1 的维度全部降维,保留不是1 的维度

2.torch.unsqueeze的用法

torch.unsqueeze 是为了升维
torch.unsqueeze(input,dim)
input: 插入的张量
dim: 指定在某个维度进行升维

W = torch.ones((2,3,5))
W.shape
torch.Size([2, 3, 5])
M = torch.unsqueeze(W, dim = 0)
M.shape
torch.Size([1, 2, 3, 5])

文章来源:https://blog.csdn.net/weixin_43588171/article/details/134997315
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。