盘点 Pytorch Vision 中的图像预训练模型
PyTorch Vision
库提供了许多经过预训练的视觉模型,包括图像分类、目标检测、语义分割等。
一、图像分类
图像分类是计算机视觉领域中最基础且常见的任务之一,其目标是将图像分配到预定义的类别中。在图像分类任务中,计算机模型需要学习从输入图像中提取特征并做出正确分类的能力。这一任务在许多应用中都是至关重要的,例如图像检索、物体识别、人脸识别等。
PyTorch Vision
库为图像分类任务提供了丰富的模型选择,从经典的AlexNet
和VGG
到深度残差网络ResNet
、密集连接网络DenseNet
,以及轻量级网络SqueezeNet
,用户可以根据任务需求选择适当的模型。
Pytorch
中支持的预训练模型在:torchvision.models
下:
基本使用,以下面图像为例:
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import requests
import matplotlib.pyplot as plt
def load_labels():
url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
response = requests.get(url)
if response.status_code == 200:
return response.json()
return []
def main():
image_path = "./images/cat.jpg"
# 加载预训练的 ResNet-50 模型
model = models.resnet50(pretrained=True)
model.eval()
# 图像预处理
input_image = Image.open(image_path)
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # 添加 batch 维度
# 将输入传递给模型并获取预测结果
with torch.no_grad():
output = model(input_batch)
# 获取预测结果中概率最高的类别
probabilities = torch.nn.functional.softmax(output[0], dim=0)
predicted_class_idx = torch.argmax(probabilities).item()
# 加载 ImageNet 类别标签
labels = load_labels()
plt.imshow(input_image)
plt.title(f"Predicted id: {predicted_class_idx}, class: {labels[str(predicted_class_idx)][1]}, score: {probabilities[predicted_class_idx].item():.3f}")
plt.show()
if __name__ == '__main__':
main()
输出:
二、目标检测
目标检测是计算机视觉领域中的一项关键任务,其目标是在图像或视频中准确识别和定位物体的位置,并为每个检测到的物体分配相应的类别标签。相较于图像分类,目标检测不仅需要识别物体的类别,还需要提供物体的边界框或轮廓信息。
PyTorch
为目标检测任务同样提供了多种强大的模型选择,从经典的Faster R-CNN
和SSD
到RetinaNet
和Fcos
,用户可以根据任务的需求选择合适的模型。这些模型的开源实现使得目标检测应用变得更加便捷,为各种场景下的实际应用提供了坚实的基础。
Pytorch
中支持的预训练模型在:torchvision.models.detection
下:
基本使用,以下面图像为例:
import torch
# 模型
# from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
# from torchvision.models.detection import ssdlite320_mobilenet_v3_large
# from torchvision.models.detection import fcos_resnet50_fpn
# from torchvision.models.detection import retinanet_resnet50_fpn_v2
# from torchvision.models.detection import ssd300_vgg16
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.transforms import functional as F
from PIL import Image,ImageDraw
from torchvision.ops import nms
import matplotlib.pyplot as plt
def main():
image_path = './images/dogs.jpg'
# 加载预训练的 Faster R-CNN 模型
model = fasterrcnn_resnet50_fpn_v2(pretrained=True)
model.eval()
# 加载图像
image = Image.open(image_path)
# 对图像进行预处理
image_tensor = F.to_tensor(image)
image_tensor = image_tensor.unsqueeze(0) # 增加 batch 维度
# 获取预测结果
with torch.no_grad():
predictions = model(image_tensor)
# 提取预测的边界框、类别和分数
boxes = predictions[0]['boxes'].cpu().numpy()
labels = predictions[0]['labels'].cpu().numpy()
scores = predictions[0]['scores'].cpu().numpy()
# 执行非极大值抑制
keep = nms(torch.tensor(boxes), torch.tensor(scores), iou_threshold=0.5)
# 保留NMS后的结果
boxes = boxes[keep]
labels = labels[keep]
scores = scores[keep]
COCO_INSTANCE_CATEGORY_NAMES=[
'__background__','person','bicycle','car','motorcycle',
'airplane','bus','train','truck','boat','traffic light',
'fire hydrant','N/A','stop sign','parking meter','bench',
'bird','cat','dog','horse','sheep','cow','elephant',
'bear','zebra','giraffe','N/A','backpack','umbrella','N/A',
'N/A','handbag','tie','suitcase','frisbee','skis','snowboard',
'surfboard','tennis racket','bottle','N/A','wine glass',
'cup','fork','knife','spoon','bowl','banana','apple',
'sandwich','orange','broccoli','carrot','hot dog','pizza',
'donut','cake','chair','couch','potted plant','bed','N/A',
'dining table','N/A','N/A','toilet','N/A','tv','laptop',
'mouse','remote','keyboard','cell phone','microwave','oven',
'toaster','sink','refrigerator','N/A','book','clock',
'vase','scissors','teddy bear','hair drier','toothbrush'
]
# 可视化结果
draw = ImageDraw.Draw(image)
for box, label, score in zip(boxes, labels, scores):
if score > 0.5: # 过滤掉低置信度的预测
box = [round(coord, 2) for coord in box]
draw.rectangle(box, outline='red', width=2)
plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], fill=False, color='red', linewidth=3)
draw.text((box[0], box[1]), f"Class {COCO_INSTANCE_CATEGORY_NAMES[label]} ({round(score, 2)})", fill='red')
plt.imshow(image)
plt.show()
if __name__ == '__main__':
main()
输出:
三、关键点检测
关键点检测也是计算机视觉领域中的一项任务,其目标是在图像中检测出具有特殊意义的关键点,这些点通常对于理解图像的结构和语义信息非常重要。关键点可以是人体的关节、面部的特定位置、物体的边缘等。关键点检测在许多领域中都有着广泛的应用,如人体姿态估计、面部识别、手部追踪等。
其中在 torchvision.models.detection
下的 keypointrcnn
系列模型可以预测出人体的关键位置点:
基本使用,以下面图像为例:
import torch
# 模型
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw
from torchvision.ops import nms
import matplotlib.pyplot as plt
def main():
image_path = './images/people.jpg'
# 加载预训练的 Faster R-CNN 模型
model = keypointrcnn_resnet50_fpn(pretrained=True)
model.eval()
# 加载图像
image = Image.open(image_path)
# 对图像进行预处理
image_tensor = F.to_tensor(image)
image_tensor = image_tensor.unsqueeze(0) # 增加 batch 维度
# 获取预测结果
with torch.no_grad():
predictions = model(image_tensor)
# 提取预测的边界框、类别和分数
boxes = predictions[0]['boxes'].cpu().numpy()
labels = predictions[0]['labels'].cpu().numpy()
scores = predictions[0]['scores'].cpu().numpy()
keypoints = predictions[0]['keypoints']
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle',
'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench',
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A',
'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock',
'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
# 可视化结果
draw = ImageDraw.Draw(image)
for box, label, score, keypoints in zip(boxes, labels, scores, keypoints):
if score > 0.5: # 过滤掉低置信度的预测
box = [round(coord, 2) for coord in box]
draw.rectangle(box, outline='red', width=2)
plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], fill=False, color='red', linewidth=3)
draw.text((box[0], box[1]), f"Class {COCO_INSTANCE_CATEGORY_NAMES[label]} ({round(score, 2)})", fill='red')
for i in range(keypoints.shape[0]):
x = keypoints[i, 0]
y = keypoints[i, 1]
visi = keypoints[i, 2]
if visi > 0:
draw.ellipse(xy=(x - 3, y - 3, x + 3, y + 3), fill=(255, 0, 0))
texts = str(i + 1)
draw.text((x + 5, y - 5), texts, fill='red')
plt.imshow(image)
plt.show()
if __name__ == '__main__':
main()
四、语义分割
语义分割是计算机视觉领域中的一项重要任务,旨在对图像中的每个像素进行分类,从而将图像划分为不同的语义区域。与目标检测不同,语义分割不仅标识出图像中存在的物体,还精确地为图像中的每个像素分配相应的语义标签,从而实现像素级别的分类。
语义分割的主要目标是理解图像中的语义结构,使计算机能够对图像中的不同区域进行更深入的理解和分析。这项任务在许多计算机视觉应用中都具有广泛的应用,例如自动驾驶、医学图像分析、图像编辑等。
Pytorch 中支持的语义分割预训练模型在:torchvision.models.segmentation
下:
基本使用,以下面图像为例:
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import matplotlib.pyplot as plt
def main():
image_path = "./images/dogs.jpg"
# 加载预训练的 DeepLabV3 模型
model = models.segmentation.deeplabv3_resnet50(pretrained=True)
# 设置模型为评估模式
model.eval()
# 图像预处理
input_image = Image.open(image_path)
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # 添加 batch 维度
# 将输入传递给模型并获取预测结果
with torch.no_grad():
output = model(input_batch)['out'][0]
output_predictions = output.argmax(0) # 获取预测结果中的类别索引
# 将类别索引转换为可视化的结果
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)
plt.imshow(r)
plt.show()
if __name__ == '__main__':
main()
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!