YOLOv7 ONNX 推理

2023-12-25 16:37:16

环境介绍

  1. 使用 mmyolo 0.6.0 导出的 YOLOv7 ONNX 模型;如果导出模型时遇到问题,可以参考我之前的文章。
  2. 我的数据集只有两个类别(person 和 soccer)。
  3. 下图是我模型的输入输出。
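    (原文此处为模型输入输出的截图,图片已缺失。根据下面的代码:输入节点名为 images,形状为 1×3×480×1920,RGB 通道顺序,像素值归一化到 [0, 1];输出节点为 617、619、621,依次对应步长 8、16、32,形状为 1×H×W×C。每个网格有 3 个 anchor,每个 anchor 有 7 个值:4 个框参数、1 个目标置信度、2 个类别分数,因此 C 应为 21。)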

因为需要在边缘设备上推理,所以先用 Python 实现一遍 ONNX 模型的后处理,代码如下:

import onnxruntime
import numpy as np
import cv2
import os
import tqdm
from torchvision.ops import nms
import torch
# Exported YOLOv7 model. A single inference session is created at import
# time and reused by every onnx_infer() call.
onnx_model_path = 'work_dir/best_coco_bbox_mAP_epoch_190.onnx'
sess = onnxruntime.InferenceSession(onnx_model_path)

# Base anchors for the stride-8 / 16 / 32 heads, loaded in that order. Each
# row is one anchor box (x1, y1, x2, y2) for grid cell (0, 0); onnx_infer()
# shifts them to every cell. The file names appear to encode the feature-map
# size for the 480x1920 input (480/8 = 60, 1920/8 = 240) -- confirm if reused.
anchor8, anchor16, anchor32 = (
    np.loadtxt(anchor_file)
    for anchor_file in ('anchor_60_240.txt', 'anchor_30_120.txt', 'anchor_15_60.txt')
)
strides = [8, 16, 32]

# Class names indexed by predicted class id. Despite the variable name these
# are the two classes of the custom dataset, not the COCO label set.
coco_classes = ['person', 'soccer']
def softmax(x, axis=0):
    """Return the softmax of ``x`` along ``axis`` (axis 0 by default).

    The maximum is subtracted per slice along ``axis`` rather than over the
    whole array. Subtracting any per-slice constant leaves the result
    unchanged and prevents overflow in ``exp``. Using the per-slice maximum
    (instead of the global one) also stops a slice whose values all lie far
    below the global maximum from underflowing to 0/0 = NaN.

    Args:
        x: array-like of scores.
        axis: axis along which the probabilities sum to 1.

    Returns:
        ndarray of the same shape as ``x``.
    """
    x = np.asarray(x)
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / exp_x.sum(axis=axis, keepdims=True)

def sigmoid(x):
    """Return the elementwise logistic sigmoid 1 / (1 + exp(-x)).

    ``exp`` is only ever evaluated on ``-|x|`` (never a positive number), so
    strongly negative logits no longer raise "overflow encountered in exp"
    warnings. Both tails stay accurate: for x >= 0 the result is
    1 / (1 + e^-x), and for x < 0 it is e^x / (1 + e^x).

    Args:
        x: scalar or array of logits.

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray of ``x``'s shape.
    """
    x = np.asarray(x)
    e = np.exp(-np.abs(x))
    # [()] unwraps a 0-d result to a scalar and is a no-op view for n-d arrays.
    return np.where(x >= 0, 1 / (1 + e), e / (1 + e))[()]

def get_color(idx):
    """Map an integer id to a deterministic 3-tuple color for OpenCV drawing.

    The id is offset by 3 and multiplied by a fixed per-channel factor; each
    channel is reduced modulo 255, so every component lies in [0, 254].
    """
    shifted = idx + 3
    return tuple(factor * shifted % 255 for factor in (37, 17, 29))

def _preprocess_image(img):
    """Turn a BGR uint8 image (OpenCV layout, H x W x 3) into the model input.

    Returns a float32 array of shape (1, 3, H, W) in RGB channel order,
    scaled to [0, 1] by dividing by 255.
    """
    rgb = img[..., ::-1]  # OpenCV loads BGR; the model is fed RGB
    chw = np.transpose(rgb, (2, 0, 1))[None, ...]
    return np.array(chw / 255, dtype=np.float32)


def _decode_level(output, anchor, stride, score_thr):
    """Decode one detection-head output into candidate boxes.

    Args:
        output: raw head output of shape (1, h, w, c). Each grid cell holds
            ``len(anchor)`` groups of ``c // len(anchor)`` values laid out as
            [tx, ty, tw, th, objectness, class logits...].
        anchor: base anchor boxes (x1, y1, x2, y2) for grid cell (0, 0),
            shape (num_anchors, 4).
        stride: pixel stride of this level.
        score_thr: candidates whose sigmoid objectness is not above this
            value are dropped.

    Returns:
        Tuple ``(boxes, confidences, class_ids)``:
        - ``boxes`` has shape (M, 4), given as (x1, y1, x2, y2) in resized-image pixels.
        - ``confidences`` is sigmoid(best class logit) * sigmoid(objectness).
        - Entries are ordered cell-major, anchor-minor, as the old per-anchor loop produced them.
    """
    _, h, w, c = output.shape
    num_anchors = len(anchor)
    # Same row-major cell order as output.reshape(-1, c), split per anchor.
    pred = output.reshape(h * w, num_anchors, c // num_anchors)

    # Shift the base anchors to every cell: cell i is at row i // w and
    # column i % w, i.e. (shiftx, shifty) = (column, row) * stride.
    rows, cols = np.divmod(np.arange(h * w), w)
    shifts = np.stack([cols, rows, cols, rows], axis=1) * stride
    cell_anchors = anchor[None, :, :] + shifts[:, None, :]  # (h*w, num_anchors, 4)

    xc = (cell_anchors[..., 0] + cell_anchors[..., 2]) / 2
    yc = (cell_anchors[..., 1] + cell_anchors[..., 3]) / 2
    wc = cell_anchors[..., 2] - cell_anchors[..., 0]
    hc = cell_anchors[..., 3] - cell_anchors[..., 1]

    # Box decoding: the centre moves within +/- 1 stride of the anchor centre;
    # width/height are the anchor size scaled by (2 * sigmoid)^2, i.e. 0..4x.
    xc_pred = (sigmoid(pred[..., 0]) - 0.5) * 2 * stride + xc
    yc_pred = (sigmoid(pred[..., 1]) - 0.5) * 2 * stride + yc
    w_pred = (sigmoid(pred[..., 2]) * 2) ** 2 * wc
    h_pred = (sigmoid(pred[..., 3]) * 2) ** 2 * hc
    obj = sigmoid(pred[..., 4])

    cls_logits = pred[..., 5:]
    # argmax over the reversed class axis picks the LAST maximum, so exact
    # ties go to the higher class id (class 1 for two classes), matching the
    # previous `if logit0 > logit1: 0 else: 1`.
    cat = cls_logits.shape[-1] - 1 - np.argmax(cls_logits[..., ::-1], axis=-1)
    conf = sigmoid(cls_logits.max(axis=-1)) * obj

    half_w, half_h = w_pred / 2, h_pred / 2
    boxes = np.stack([xc_pred - half_w, yc_pred - half_h,
                      xc_pred + half_w, yc_pred + half_h], axis=-1)

    # The threshold is applied to objectness alone; conf is used for NMS ranking.
    keep = obj > score_thr
    return boxes[keep], conf[keep], cat[keep]


def _draw_detections(img, bboxes, scores, cats):
    """Draw each box and a '<class>: <score>' label onto ``img`` in place."""
    for (x0, y0, x1, y1), conf, cat in zip(bboxes, scores, cats):
        color = get_color(int(cat))
        top_left = (int(x0), int(y0))
        cv2.rectangle(img, top_left, (int(x1), int(y1)), color, 2)
        label = f'{coco_classes[int(cat)]}: {conf:.2f}'
        cv2.putText(img, label, (top_left[0], top_left[1] - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)


def onnx_infer(img_path, score_thr=0.1, iou_thr=0.5, out_dir='onnx_infer'):
    """Run the ONNX model on one image and save a copy with detections drawn.

    Steps:
    1. Resize the image to 1920x480 (width x height). The anchors are tied to
       this size, so it is not a parameter.
    2. Decode each stride level.
    3. Filter candidates by objectness.
    4. Suppress duplicates with class-agnostic NMS.
    5. Write the annotated *resized* image to ``out_dir`` under the input's
       file name.

    Args:
        img_path: path of the image to process.
        score_thr: objectness threshold applied before NMS.
        iou_thr: IoU threshold for NMS.
        out_dir: output directory; created if it does not exist.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
        OSError: if the annotated image cannot be written.
    """
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError(f'cannot read image: {img_path}')
    img = cv2.resize(img, (1920, 480))
    input_data = _preprocess_image(img)

    output_data = sess.run(['617', '619', '621'], {'images': input_data})

    decoded = [
        _decode_level(output, anchor, stride, score_thr)
        for output, anchor, stride in zip(output_data, (anchor8, anchor16, anchor32), strides)
    ]
    # Each level returns (M, 4) boxes even when M == 0, so concatenation is safe.
    bboxes = np.concatenate([level[0] for level in decoded])
    scores = np.concatenate([level[1] for level in decoded])
    cats = np.concatenate([level[2] for level in decoded])

    if len(bboxes) > 0:
        # Convert keep to numpy: numpy treats a 1-element torch tensor as a
        # scalar index, which would collapse the arrays when one box survives.
        keep = nms(torch.as_tensor(bboxes, dtype=torch.float32),
                   torch.as_tensor(scores, dtype=torch.float32),
                   iou_threshold=iou_thr).numpy()
    else:
        keep = np.zeros(0, dtype=np.int64)

    _draw_detections(img, bboxes[keep], scores[keep], cats[keep])

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, os.path.basename(img_path))
    if not cv2.imwrite(out_path, img):  # imwrite reports failure via its return value
        raise OSError(f'failed to write {out_path}')
    

if __name__ == '__main__':
    data_root = 'mydemo/'
    # Select images first so the progress bar total is accurate.
    # - The extension match is case-insensitive and requires the dot, so
    #   ".JPG" and ".jpeg" are processed and "xjpg" is not.
    # - Sorting makes the processing order reproducible (os.listdir order is
    #   arbitrary).
    img_names = sorted(
        name for name in os.listdir(data_root)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    for img_name in tqdm.tqdm(img_names):
        onnx_infer(os.path.join(data_root, img_name))

anchor 是在下面这个文件中保存的:

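(原文此处为截图,图片已缺失。代码中读取的 anchor 文件为 anchor_60_240.txt、anchor_30_120.txt、anchor_15_60.txt,依次对应步长 8、16、32 的特征图。每个文件每行一个 anchor 框,格式为 x1 y1 x2 y2。)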

文章来源:https://blog.csdn.net/qq_33596242/article/details/135202147
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。