pytorch中常见的合并和分割

2023-12-22 11:48:43
import torch

'''
1-数据合并:cat(沿着维度);stack(增加维度)
2-数据分割:split(按照长度);chunk(按照数量)

'''

a1 = torch.rand(2, 3, 28, 28)
a2 = torch.rand(4, 3, 28, 28)
'''
torch.cat([a1, a2], dim=0)
dim = 指定拼接的维度
注意,使用torch.cat进行拼接时除了拼接维度可以不同外,其他的维度必须相同
'''


def zqb_cat():
    print(a1.shape, a2.shape

文章来源:https://blog.csdn.net/zqb_123/article/details/135148142
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。