盘点 Pytorch Vision 中的图像预训练模型

2023-12-16 17:01:53

PyTorch Vision 库提供了许多经过预训练的视觉模型,包括图像分类、目标检测、语义分割等。

一、图像分类

图像分类是计算机视觉领域中最基础且常见的任务之一,其目标是将图像分配到预定义的类别中。在图像分类任务中,计算机模型需要学习从输入图像中提取特征并做出正确分类的能力。这一任务在许多应用中都是至关重要的,例如图像检索、物体识别、人脸识别等。

PyTorch Vision库为图像分类任务提供了丰富的模型选择,从经典的AlexNetVGG到深度残差网络ResNet、密集连接网络DenseNet,以及轻量级网络SqueezeNet,用户可以根据任务需求选择适当的模型。

Pytorch 中支持的预训练模型在:torchvision.models 下:

在这里插入图片描述

基本使用,以下面图像为例:

在这里插入图片描述

import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import requests
import matplotlib.pyplot as plt

def load_labels():
    url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
    response = requests.get(url)
    if response.status_code == 200:
        return response.json()
    return []

def main():
    image_path = "./images/cat.jpg"
    # 加载预训练的 ResNet-50 模型
    model = models.resnet50(pretrained=True)
    model.eval()
    # 图像预处理
    input_image = Image.open(image_path)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # 添加 batch 维度

    # 将输入传递给模型并获取预测结果
    with torch.no_grad():
        output = model(input_batch)

    # 获取预测结果中概率最高的类别
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    predicted_class_idx = torch.argmax(probabilities).item()

    # 加载 ImageNet 类别标签
    labels = load_labels()

    plt.imshow(input_image)
    plt.title(f"Predicted id: {predicted_class_idx}, class: {labels[str(predicted_class_idx)][1]}, score: {probabilities[predicted_class_idx].item():.3f}")
    plt.show()

if __name__ == '__main__':
    main()

输出:
在这里插入图片描述

二、目标检测

目标检测是计算机视觉领域中的一项关键任务,其目标是在图像或视频中准确识别和定位物体的位置,并为每个检测到的物体分配相应的类别标签。相较于图像分类,目标检测不仅需要识别物体的类别,还需要提供物体的边界框或轮廓信息。

PyTorch为目标检测任务同样提供了多种强大的模型选择,从经典的Faster R-CNNSSDRetinaNetFcos,用户可以根据任务的需求选择合适的模型。这些模型的开源实现使得目标检测应用变得更加便捷,为各种场景下的实际应用提供了坚实的基础。

Pytorch 中支持的预训练模型在:torchvision.models.detection 下:

在这里插入图片描述

基本使用,以下面图像为例:

在这里插入图片描述

import torch
# 模型
# from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
# from torchvision.models.detection import ssdlite320_mobilenet_v3_large
# from torchvision.models.detection import fcos_resnet50_fpn
# from torchvision.models.detection import retinanet_resnet50_fpn_v2
# from torchvision.models.detection import ssd300_vgg16
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.transforms import functional as F
from PIL import Image,ImageDraw
from torchvision.ops import nms
import matplotlib.pyplot as plt


def main():
    image_path = './images/dogs.jpg'

    # 加载预训练的 Faster R-CNN 模型
    model = fasterrcnn_resnet50_fpn_v2(pretrained=True)
    model.eval()

    # 加载图像
    image = Image.open(image_path)

    # 对图像进行预处理
    image_tensor = F.to_tensor(image)
    image_tensor = image_tensor.unsqueeze(0)  # 增加 batch 维度

    # 获取预测结果
    with torch.no_grad():
        predictions = model(image_tensor)

    # 提取预测的边界框、类别和分数
    boxes = predictions[0]['boxes'].cpu().numpy()
    labels = predictions[0]['labels'].cpu().numpy()
    scores = predictions[0]['scores'].cpu().numpy()

    # 执行非极大值抑制
    keep = nms(torch.tensor(boxes), torch.tensor(scores), iou_threshold=0.5)

    # 保留NMS后的结果
    boxes = boxes[keep]
    labels = labels[keep]
    scores = scores[keep]

    COCO_INSTANCE_CATEGORY_NAMES=[
        '__background__','person','bicycle','car','motorcycle',
        'airplane','bus','train','truck','boat','traffic light',
        'fire hydrant','N/A','stop sign','parking meter','bench',
        'bird','cat','dog','horse','sheep','cow','elephant',
        'bear','zebra','giraffe','N/A','backpack','umbrella','N/A',
        'N/A','handbag','tie','suitcase','frisbee','skis','snowboard',
        'surfboard','tennis racket','bottle','N/A','wine glass',
        'cup','fork','knife','spoon','bowl','banana','apple',
        'sandwich','orange','broccoli','carrot','hot dog','pizza',
        'donut','cake','chair','couch','potted plant','bed','N/A',
        'dining table','N/A','N/A','toilet','N/A','tv','laptop',
        'mouse','remote','keyboard','cell phone','microwave','oven',
        'toaster','sink','refrigerator','N/A','book','clock',
        'vase','scissors','teddy bear','hair drier','toothbrush'
    ]
    # 可视化结果
    draw = ImageDraw.Draw(image)
    for box, label, score in zip(boxes, labels, scores):
        if score > 0.5:  # 过滤掉低置信度的预测
            box = [round(coord, 2) for coord in box]
            draw.rectangle(box, outline='red', width=2)
            plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], fill=False, color='red', linewidth=3)
            draw.text((box[0], box[1]), f"Class {COCO_INSTANCE_CATEGORY_NAMES[label]} ({round(score, 2)})", fill='red')

    plt.imshow(image)
    plt.show()

if __name__ == '__main__':
    main()

输出:

在这里插入图片描述

三、关键点检测

关键点检测也是计算机视觉领域中的一项任务,其目标是在图像中检测出具有特殊意义的关键点,这些点通常对于理解图像的结构和语义信息非常重要。关键点可以是人体的关节、面部的特定位置、物体的边缘等。关键点检测在许多领域中都有着广泛的应用,如人体姿态估计、面部识别、手部追踪等。

其中在 torchvision.models.detection 下的 keypointrcnn 系列模型可以预测出人体的关键位置点:

基本使用,以下面图像为例:

在这里插入图片描述

import torch
# 模型
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw
from torchvision.ops import nms
import matplotlib.pyplot as plt


def main():
    image_path = './images/people.jpg'

    # 加载预训练的 Faster R-CNN 模型
    model = keypointrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    # 加载图像
    image = Image.open(image_path)

    # 对图像进行预处理
    image_tensor = F.to_tensor(image)
    image_tensor = image_tensor.unsqueeze(0)  # 增加 batch 维度

    # 获取预测结果
    with torch.no_grad():
        predictions = model(image_tensor)

    # 提取预测的边界框、类别和分数
    boxes = predictions[0]['boxes'].cpu().numpy()
    labels = predictions[0]['labels'].cpu().numpy()
    scores = predictions[0]['scores'].cpu().numpy()
    keypoints = predictions[0]['keypoints']

    COCO_INSTANCE_CATEGORY_NAMES = [
        '__background__', 'person', 'bicycle', 'car', 'motorcycle',
        'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
        'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
        'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
        'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A',
        'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
        'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
        'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock',
        'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

    # 可视化结果
    draw = ImageDraw.Draw(image)
    for box, label, score, keypoints in zip(boxes, labels, scores, keypoints):
        if score > 0.5:  # 过滤掉低置信度的预测
            box = [round(coord, 2) for coord in box]
            draw.rectangle(box, outline='red', width=2)
            plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], fill=False, color='red', linewidth=3)
            draw.text((box[0], box[1]), f"Class {COCO_INSTANCE_CATEGORY_NAMES[label]} ({round(score, 2)})", fill='red')

            for i in range(keypoints.shape[0]):
                x = keypoints[i, 0]
                y = keypoints[i, 1]
                visi = keypoints[i, 2]
                if visi > 0:
                    draw.ellipse(xy=(x - 3, y - 3, x + 3, y + 3), fill=(255, 0, 0))
                    texts = str(i + 1)
                    draw.text((x + 5, y - 5), texts, fill='red')


    plt.imshow(image)
    plt.show()


if __name__ == '__main__':
    main()

在这里插入图片描述

四、语义分割

语义分割是计算机视觉领域中的一项重要任务,旨在对图像中的每个像素进行分类,从而将图像划分为不同的语义区域。与目标检测不同,语义分割不仅标识出图像中存在的物体,还精确地为图像中的每个像素分配相应的语义标签,从而实现像素级别的分类。

语义分割的主要目标是理解图像中的语义结构,使计算机能够对图像中的不同区域进行更深入的理解和分析。这项任务在许多计算机视觉应用中都具有广泛的应用,例如自动驾驶、医学图像分析、图像编辑等。

Pytorch 中支持的语义分割预训练模型在:torchvision.models.segmentation 下:

在这里插入图片描述

基本使用,以下面图像为例:

在这里插入图片描述

import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import matplotlib.pyplot as plt


def main():
    image_path = "./images/dogs.jpg"
    # 加载预训练的 DeepLabV3 模型
    model = models.segmentation.deeplabv3_resnet50(pretrained=True)

    # 设置模型为评估模式
    model.eval()

    # 图像预处理
    input_image = Image.open(image_path)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # 添加 batch 维度

    # 将输入传递给模型并获取预测结果
    with torch.no_grad():
        output = model(input_batch)['out'][0]
    output_predictions = output.argmax(0)  # 获取预测结果中的类别索引

    # 将类别索引转换为可视化的结果
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)

    plt.imshow(r)
    plt.show()

if __name__ == '__main__':
    main()

在这里插入图片描述

文章来源:https://blog.csdn.net/qq_43692950/article/details/135032533
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。