RT-DETR模型导出与推理
2023-12-26 22:23:33
1.准备工作
RT-DETR模型训练可参考:http://t.csdnimg.cn/Fsph5
模型导出与模型推理需要安装onnx库和onnxruntime库
可通过以下命令安装:
pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
2.onnx模型导出
首先找到export_onnx.py文件,该文件位于RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/export_onnx.py
然后修改config与resume参数,使其路径为你的具体路径。
其中resume参数为训练生成的pth权重文件路径。
"""by lyuwenyu
"""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
import argparse
import numpy as np
from src.core import YAMLConfig
import torch
import torch.nn as nn
def main(args, ):
"""main
"""
cfg = YAMLConfig(args.config, resume=args.resume)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
if 'ema' in checkpoint:
state = checkpoint['ema']['module']
else:
state = checkpoint['model']
else:
raise AttributeError('only support resume to load model.state_dict by now.')
# NOTE load train mode state -> convert to deploy mode
cfg.model.load_state_dict(state)
class Model(nn.Module):
def __init__(self, ) -> None:
super().__init__()
self.model = cfg.model.deploy()
self.postprocessor = cfg.postprocessor.deploy()
print(self.postprocessor.deploy_mode)
def forward(self, images, orig_target_sizes):
outputs = self.model(images)
return self.postprocessor(outputs, orig_target_sizes)
model = Model()
dynamic_axes = {
'images': {0: 'N', },
'orig_target_sizes': {0: 'N'}
}
data = torch.rand(1, 3, 640, 640)
size = torch.tensor([[640, 640]])
torch.onnx.export(
model,
(data, size),
args.file_name,
input_names=['images', 'orig_target_sizes'],
output_names=['labels', 'boxes', 'scores'],
dynamic_axes=dynamic_axes,
opset_version=16,
verbose=False
)
if args.check:
import onnx
onnx_model = onnx.load(args.file_name)
onnx.checker.check_model(onnx_model)
print('Check export onnx model done...')
if args.simplify:
import onnxsim
dynamic = True
input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else None
onnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)
onnx.save(onnx_model_simplify, args.file_name)
print(f'Simplify onnx model {check}...')
# import onnxruntime as ort
# from PIL import Image, ImageDraw
# from torchvision.transforms import ToTensor
# # print(onnx.helper.printable_graph(mm.graph))
# im = Image.open('./000000014439.jpg').convert('RGB')
# im = im.resize((640, 640))
# im_data = ToTensor()(im)[None]
# print(im_data.shape)
# sess = ort.InferenceSession(args.file_name)
# output = sess.run(
# # output_names=['labels', 'boxes', 'scores'],
# output_names=None,
# input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
# )
# # print(type(output))
# # print([out.shape for out in output])
# labels, boxes, scores = output
# draw = ImageDraw.Draw(im)
# thrh = 0.6
# for i in range(im_data.shape[0]):
# scr = scores[i]
# lab = labels[i][scr > thrh]
# box = boxes[i][scr > thrh]
# print(i, sum(scr > thrh))
# for b in box:
# draw.rectangle(list(b), outline='red',)
# draw.text((b[0], b[1]), text=str(lab[i]), fill='blue', )
# im.save('test.jpg')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c', type=str,default = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/configs/rtdetr/rtdetr_r18vd_6x_coco.yml" )
parser.add_argument('--resume', '-r', type=str,default = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/output/rtdetr_r18vd_6x_coco/checkpoint0012.pth" )
parser.add_argument('--file-name', '-f', type=str, default='model.onnx')
parser.add_argument('--check', action='store_true', default=False,)
parser.add_argument('--simplify', action='store_true', default=False,)
args = parser.parse_args()
main(args)
修改完毕后,即可运行export_onnx.py,生成的onnx文件位于该py文件的同级目录。
3.推理
在tools文件夹下创建mypredict.py
mypredict.py的代码如下:
你需要修改img_path ,使其为你推理所需的图像路径。
img.save()中的路径修改为你的推理结果保存的路径。
-------2023.12.21更新--------
按照你的数据集中的类别修改classes
import torch
import onnxruntime as ort
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensor
if __name__ == "__main__":
##################
classes = ['','LicensePlate']
##################
# print(onnx.helper.printable_graph(mm.graph))
#############
img_path = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/input/IMG_8669.jpg"
#############
im = Image.open(img_path).convert('RGB')
im = im.resize((640, 640))
im_data = ToTensor()(im)[None]
print(im_data.shape)
size = torch.tensor([[640, 640]])
sess = ort.InferenceSession("model.onnx")
output = sess.run(
# output_names=['labels', 'boxes', 'scores'],
output_names=None,
input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
)
# print(type(output))
# print([out.shape for out in output])
labels, boxes, scores = output
draw = ImageDraw.Draw(im)
thrh = 0.6
for i in range(im_data.shape[0]):
scr = scores[i]
lab = labels[i][scr > thrh]
box = boxes[i][scr > thrh]
print(i, sum(scr > thrh))
#print(lab)
print(f'box:{box}')
for l, b in zip(lab, box):
draw.rectangle(list(b), outline='red',)
print(l.item())
draw.text((b[0], b[1] - 10), text=str(classes[l.item()]), fill='blue', )
#############
im.save('/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/output/predict/res.jpg')
#############
运行mypredict.py,得到推理结果。
文章来源:https://blog.csdn.net/weixin_53895623/article/details/135128497
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!