现有网络模型的使用及修改(VGG16为例)
2024-01-07 22:28:35
VGG16
修改默认路径
import os
os.environ['TORCH_HOME'] = r'D:\Pytorch\pythonProject\vgg16' # 下载位置
太大了(140多G)不提供直接下载
train_set = torchvision.datasets.ImageNet(root='./data_image_net', split='train', download=True
, transform=torchvision.transforms.ToTensor())
是否预训练
不预训练:采用随机参数
预训练:采用训练好的参数
第一次
第二次
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights='DEFAULT') # or weights='IMAGENET1K_V1'
完整代码
import torchvision
import os
os.environ['TORCH_HOME'] = r'D:\Pytorch\pythonProject\vgg16' # 下载位置
# train_set = torchvision.datasets.ImageNet(root='./data_image_net', split='train', download=True
# , transform=torchvision.transforms.ToTensor())
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights='DEFAULT') # or weights='IMAGENET1K_V1'
print(vgg16_true)
加一层线性层-nn.Linear
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
如果想加到classifier里面
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
修改神经网络某层
vgg16_false.classifier[6] = nn.Linear(4096, 10)
改之前
print(vgg16_false)
改之后
文章来源:https://blog.csdn.net/weixin_51788042/article/details/135445000
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!