图像识别快速实现
2024-01-03 19:05:46
文本的跑通了,接下来玩玩图片场景
1. 引入模型
再另起类test_qdrant_img.py,转化图片用到的模型和文本不太一样,我们这里使用ResNet-50模型
import unittest
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client import QdrantClient
import torch
import torchvision.transforms as transforms
from PIL import Image
class TestQDrantImg(unittest.TestCase):
def setUp(self):
self.collection_name = "img_collection"
self.client = QdrantClient("localhost", port=6333)
# 加载ResNet-50模型
self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
self.model.eval()
# 图像预处理
self.preprocess = transforms.Compose([
# 图像调整为256*256
transforms.Resize(256),
# 中心裁剪为224*224
transforms.CenterCrop(224),
# 转换为张量,像素值从范围[0,255]缩放到范围[0,1],RGB(红绿蓝)转换为通道顺序(即 RGB 顺序)
transforms.ToTensor(),
# 应用归一化,减去均值(mean)并除以标准差(std)
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
2. 添加图片向量
我们先创建一个新集合
def test_create_collection(self):
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=1000, distance=Distance.EUCLID),
)
往集合里分别添加1个猫的图片和1个狗的图片
def test_img_vector(self):
# 加载并预处理图像
id = 1
image_path = './img/cat1.png'
# id = 2
# image_path = './img/dog1.png'
image = Image.open(image_path)
image_tensor = self.preprocess(image)
# 在第0维度上添加一个维度,将图像张量转换为形状为 (1, C, H, W) 的张量,其中 C 是通道数,H 是高度,W 是宽度
image_tensor = torch.unsqueeze(image_tensor, 0)
with torch.no_grad():
# 去除维度为1的维度,将特征向量的形状从 (1, D) 转换为 (D,)
feature_vector = self.model(image_tensor).squeeze().tolist()
operation_info = self.client.upsert(
collection_name=self.collection_name,
points=[{'id': id, 'vector': feature_vector, 'payload': {"image_path": image_path}}]
)
print(operation_info)
3. 匹配图片向量
然后用其他猫狗的图片来做搜索匹配
def test_search(self):
# 加载并预处理图像
image_path = './img/cat2.png'
# image_path = './img/dog2.png'
# image_path = './img/cat3.png'
image = Image.open(image_path)
image_tensor = self.preprocess(image)
image_tensor = torch.unsqueeze(image_tensor, 0)
with torch.no_grad():
feature_vector = self.model(image_tensor).squeeze().tolist()
search_result = self.client.search(
collection_name=self.collection_name, query_vector=feature_vector, limit=3
, with_vectors=True, with_payload=True
)
print(search_result)
结果:
[ScoredPoint(id = 1, version = 0, score = 68.21013, payload = {
'image_path': './img/cat1.png'
}, vector = [...]),
ScoredPoint(id = 2, version = 1, score = 85.10757, payload = {
'image_path': './img/dog1.png'
}, vector = [...])]
当使用猫2猫3作为查询条件时,跟猫1记录的score(向量距离)较小;
同理,使用狗2作为查询条件时,跟狗1记录的score(向量距离)较小
文章来源:https://blog.csdn.net/cxs812760493/article/details/135346484
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!