RTDETR模型一键训练/预测(执行train.sh与detect.sh)
文章目录
引言
本文章基于客户一键训练与测试需求,我使用u公司的yolov8集成的RTDETR模型改成较为保姆级的一键
操作的训练/预测方式,也特别适合新手或想偷懒转换数据格式的朋友们。本文一键体现数据格式为图像与xml,调用train.sh与detect.sh可完成模型的训练与预测。而为完成该操作,模型内嵌入xml转RTDETR的txt格式、自动分配训练/验证集、自动切换环境等内容。接下来,我将介绍如何操作,并附修改源码。
源码链接:我已上传个人资源,请自行下载!
一、配置参数设置
该文件是RTDETR数据转换配置和模型使用参数,被我修改满足一键训练与测试文件的配置参数。包含将图像与xml文件数据格式转为模型训练格式数据,只需要提供xml与图像文件夹
,可完成数据转换,详情如下:
# 设置img与xml的文件路径,也可为同一个文件,按照xml选择img
img_path: C:/Users/Administrator/Desktop/rtdetr/example_template/data #
xml_path: C:/Users/Administrator/Desktop/rtdetr/example_template/data
# 设置数据集训练与验证集测试的比率,和小于1,通常test比率不设置为0
train_rate: 0.8
val_rate: 0.2
test_rate:
path: C:/Users/Administrator/Desktop/rtdetr/example_template/rtdert_data # 必填,转换存放数据集文件夹,必须设置
train: images/train # 不设置
val: images/val # 不设置
test:
# Classes
names:
0: person
1: bicycle
2: car
3: motorcycle
4: airplane
5: bus
6: train
7: truck
8: boat
9: traffic light
10: fire hydrant
11: stop sign
12: parking meter
13: bench
14: bird
15: cat
16: dog
17: horse
18: sheep
19: cow
20: elephant
21: bear
22: zebra
23: giraffe
24: backpack
25: umbrella
26: handbag
27: tie
28: suitcase
29: frisbee
30: skis
31: snowboard
32: sports ball
33: kite
34: baseball bat
35: baseball glove
36: skateboard
37: surfboard
38: tennis racket
39: bottle
40: wine glass
41: cup
42: fork
43: knife
44: spoon
45: bowl
46: banana
47: apple
48: sandwich
49: orange
50: broccoli
51: carrot
52: hot dog
53: pizza
54: donut
55: cake
56: chair
57: couch
58: potted plant
59: bed
60: dining table
61: toilet
62: tv
63: laptop
64: mouse
65: remote
66: keyboard
67: cell phone
68: microwave
69: oven
70: toaster
71: sink
72: refrigerator
73: book
74: clock
75: vase
76: scissors
77: teddy bear
78: hair drier
79: toothbrush
二、数据格式转换代码
该文件代码提供了xml格式转rtdetr模型需要格式,基本是属于逻辑,代码能力较为基础,我不在介绍,代码如下:
import pandas as pd
import cv2
from tqdm import tqdm
import os
import numpy as np
import json
import xml.etree.ElementTree as ET
from lxml.etree import Element, SubElement, tostring, ElementTree
from xml.dom.minidom import parseString
import random
import shutil
import yaml
img_format = ['.jpg', '.png', '.bmp']
def build_dir(root):
import os
if not os.path.exists(root):
os.makedirs(root)
return root
def del_dir(root):
import os
if os.path.exists(root):
shutil.rmtree(root)
return root
############################################生成xml方法##########################
def product_xml(name_img, boxes, codes, img=None, wh=None):
'''
:param img: 以读好的图片
:param name_img: 图片名字,如'xxx.jpg'
:param boxes: box为列表
:param codes: 为列表
:return:
'''
if img is not None:
width = img.shape[0]
height = img.shape[1]
else:
assert wh is not None
width = wh[0]
height = wh[1]
node_root = Element('annotation')
node_folder = SubElement(node_root, 'folder')
node_folder.text = 'VOC2007'
node_filename = SubElement(node_root, 'filename')
node_filename.text = name_img # 图片名字
node_size = SubElement(node_root, 'size')
node_width = SubElement(node_size, 'width')
node_width.text = str(height)
node_height = SubElement(node_size, 'height')
node_height.text = str(width)
node_depth = SubElement(node_size, 'depth')
node_depth.text = '3'
for i, code in enumerate(codes):
box = [boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]]
node_object = SubElement(node_root, 'object')
node_name = SubElement(node_object, 'name')
node_name.text = code
node_difficult = SubElement(node_object, 'difficult')
node_difficult.text = '0'
node_bndbox = SubElement(node_object, 'bndbox')
node_xmin = SubElement(node_bndbox, 'xmin')
node_xmin.text = str(int(box[0]))
node_ymin = SubElement(node_bndbox, 'ymin')
node_ymin.text = str(int(box[1]))
node_xmax = SubElement(node_bndbox, 'xmax')
node_xmax.text = str(int(box[2]))
node_ymax = SubElement(node_bndbox, 'ymax')
node_ymax.text = str(int(box[3]))
xml = tostring(node_root, pretty_print=True) # 格式化显示,该换行的换行
dom = parseString(xml)
name = name_img[:-4] + '.xml'
tree = ElementTree(node_root)
print('name:{},dom:{}'.format(name, dom))
return tree, name
def product_xml_demo():
'''
通过box与cat信息为图片产生xml文件
'''
img_root = r'C:\Users\Administrator\Desktop\123\1.jpg'
write_img_name = 'hhhaaa.jpg'
bboxes_lst = [[22, 32, 46, 89]]
cat_lst = ['cat']
img = cv2.imread(img_root)
tree, xml_name = product_xml(write_img_name, bboxes_lst, cat_lst, img=img)
tree.write(os.path.join('./', xml_name))
############################################xml转yolo的txt##########################
def read_xml(xml_root):
'''
:param xml_root: .xml文件
:return: dict('cat':['cat1',...],'bboxes':[[x1,y1,x2,y2],...],'whd':[w ,h,d])
'''
dict_info = {'cat': [], 'bboxes': [], 'box_wh': [], 'whd': []}
if os.path.splitext(xml_root)[-1] == '.xml':
tree = ET.parse(xml_root) # ET是一个xml文件解析库,ET.parse()打开xml文件。parse--"解析"
root = tree.getroot() # 获取根节点
whd = root.find('size')
whd = [whd.find('width').text, whd.find('height').text, whd.find('depth').text]
for obj in root.findall('object'): # 找到根节点下所有“object”节点
cat = str(obj.find('name').text) # 找到object节点下name子节点的值(字符串)
bbox = obj.find('bndbox')
x1, y1, x2, y2 = [int(bbox.find('xmin').text),
int(bbox.find('ymin').text),
int(bbox.find('xmax').text),
int(bbox.find('ymax').text)]
b_w = x2 - x1 + 1
b_h = y2 - y1 + 1
dict_info['cat'].append(cat)
dict_info['bboxes'].append([x1, y1, x2, y2])
dict_info['box_wh'].append([b_w, b_h])
dict_info['whd'].append(whd)
else:
print('[inexistence]:{} suffix is not xml '.format(xml_root))
return dict_info
def write_txt(text_lst, out_txt=None):
'''
每行内容为列表,将其写入text中
'''
out_dir = out_txt if out_txt is not None else 'classes.txt'
file_write_obj = open(out_dir, 'w', encoding='utf-8') # 以写的方式打开文件,如果文件不存在,就会自动创建
for text in text_lst:
file_write_obj.writelines(str(text))
file_write_obj.write('\n')
file_write_obj.close()
def xml2yolotxt(xml_root, img_root=None, save_txt=None, labels_name_lst=None):
'''
:param xml_root: xml的路径
:param img_root:图像路径,可提供也可不提供,提供主要获得图像的高宽
:param out_file:保存txt路径的文件夹
:param labels_name_lst:提供训练列表,xml中出现类别与列表对应,如['pedes', 'elec', 'car', 'truck', 'bus', 'tricycle']
pedes表示0,elec表示1,car表示2等
:return:
'''
if labels_name_lst is None:
raise ValueError("lack labels list ")
if save_txt is None:
raise ValueError("lack saving root for txt file ")
xml_info = read_xml(xml_root)
if img_root is not None:
# 从中提取W与H
img = cv2.imread(img_root)
H, W = img.shape[:2]
else:
whd = xml_info['whd'][0]
W, H = float(whd[0]), float(whd[1])
boxes_lst = xml_info['bboxes']
labels_lst = xml_info['cat']
yolotxt_lst = []
for i, b in enumerate(boxes_lst):
label = labels_lst[i]
if label in labels_name_lst:
label_idx = list(labels_name_lst).index(label)
bw, bh = b[2] - b[0], b[3] - b[1]
x, y = b[0] + bw / 2, b[1] + bh / 2
x, y, w, h = x / W, y / H, bw / W, bh / H
# yolotxt = str(cat_lst[i]) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h)
yolotxt = str(label_idx) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h)
yolotxt_lst.append(yolotxt)
if len(yolotxt_lst) > 0:
write_txt(yolotxt_lst, save_txt)
def convert_data_train(xml_path, img_path, out_file_path, labels_name_lst, **kwargs):
'''
xml_path:xml文件夹的路径
img_path:图片文件夹的路径
out_file_path:模型训练的文件夹,用于yolo模型训练
labels_name_lst:标签列表,模型只转换与训练的标签列表
kwargs:其它参数
'''
print('\n convert data...')
img_suffix = kwargs.get('img_suffix') if kwargs.get('img_suffix') else 4
img_names = [name for name in os.listdir(img_path) if name[-4:] in img_format]
img_names_no_suffix = [name[:-img_suffix] for name in img_names]
xml_names_temp = [name for name in os.listdir(xml_path) if name[-3:] == 'xml']
N = len(xml_names_temp)
N_idx = [i for i in range(N)]
random.shuffle(N_idx)
xml_names = [xml_names_temp[i] for i in N_idx]
train_N = N * kwargs.get('train_rate') if kwargs.get('train_rate') else 0.7 * N
val_N = N * kwargs.get('val_rate') if kwargs.get('val_rate') else 0.3 * N
test_N = N * kwargs.get('test_rate') if kwargs.get('test_rate') else 0
if (train_N / N + val_N / N + test_N / N) > 1:
raise ValueError(
"rate of datasets error,sum>1, train_rate:{}\tval_rate:{}\ttest_rate{}".format(train_N / N, val_N / N,
test_N / N))
# 构建训练文件
images_path = os.path.join(out_file_path, 'images')
labels_path = os.path.join(out_file_path, 'labels')
del_dir(images_path)
del_dir(labels_path)
build_dir(images_path)
build_dir(labels_path)
train_img_path = build_dir(os.path.join(images_path, 'train'))
val_img_path = build_dir(os.path.join(images_path, 'val'))
test_img_path = build_dir(os.path.join(images_path, 'test'))
train_label_path = build_dir(os.path.join(labels_path, 'train'))
val_label_path = build_dir(os.path.join(labels_path, 'val'))
test_label_path = build_dir(os.path.join(labels_path, 'test'))
problem_xmls=[]
for i in tqdm(range(int(train_N))):
xml_name = xml_names[i]
xml_root = os.path.join(xml_path, xml_name)
if xml_name[:-4] in list(img_names_no_suffix):
img_idx = list(img_names_no_suffix).index(xml_name[:-4])
img_name = img_names[img_idx]
img_root = os.path.join(img_path, img_name)
save_txt = os.path.join(train_label_path, xml_name[:-3] + 'txt')
try:
xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
except:
problem_xmls.append(xml_root)
break
shutil.copy(img_root, os.path.join(train_img_path, img_name))
print('\nfinishing vonvert of train data,train_rate:\t{}\t train count:\t{} \n'.format(train_N / N, int(train_N)))
for i in tqdm(range(int(train_N), int(train_N + val_N))):
xml_name = xml_names[i]
xml_root = os.path.join(xml_path, xml_name)
if xml_name[:-4] in list(img_names_no_suffix):
img_idx = list(img_names_no_suffix).index(xml_name[:-4])
img_name = img_names[img_idx]
img_root = os.path.join(img_path, img_name)
save_txt = os.path.join(val_label_path, xml_name[:-3] + 'txt')
try:
xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
except:
problem_xmls.append(xml_root)
break
# xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
shutil.copy(img_root, os.path.join(val_img_path, img_name))
print('\nfinishing vonvert of val data, val_rate:\t{}\t val count:\t{} \n'.format(val_N / N, int(val_N)))
for i in tqdm(range(int(train_N + val_N), int(train_N + val_N + test_N))):
xml_name = xml_names[i]
xml_root = os.path.join(xml_path, xml_name)
if xml_name[:-4] in list(img_names_no_suffix):
img_idx = list(img_names_no_suffix).index(xml_name[:-4])
img_name = img_names[img_idx]
img_root = os.path.join(img_path, img_name)
save_txt = os.path.join(test_label_path, xml_name[:-3] + 'txt')
try:
xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
except:
problem_xmls.append(xml_root)
break
# xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
shutil.copy(img_root, os.path.join(test_img_path, img_name))
print('\nfinishing vonvert of test data, test_rate:\t{}\t test count:\t{} \n'.format(test_N / N, int(test_N)))
print( '\n problem xml:{}\n'.format(len(problem_xmls)) )
for probel_path in problem_xmls:
print(probel_path)
def product_yolo_dataset(yaml_path):
f = open(yaml_path, 'rb')
cfg = yaml.load(f, Loader=yaml.FullLoader)
img_path = cfg['img_path']
xml_path = cfg['xml_path']
out_file_path = cfg['path']
labels_name_lst = [v for k,v in cfg['names'].items()]
kwargs = {"train_rate": cfg['train_rate'], "val_rate": cfg['val_rate'], "test_rate": cfg['test_rate']}
convert_data_train(xml_path, img_path, out_file_path, labels_name_lst, **kwargs)
return cfg
def yolo_dataset_demo():
'''
将xml数据格式转换为yolo格式的方法
'''
yaml_path = 'coco128_auto.yaml'
product_yolo_dataset(yaml_path)
def read_yaml(yaml_path):
f = open(yaml_path, 'rb')
cfg = yaml.load(f, Loader=yaml.FullLoader)
return cfg
def del_runsfile():
from pathlib import Path
import sys
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
del_dir(ROOT/'runs/detect/train')
if __name__ == '__main__':
yolo_dataset_demo()
del_runsfile() # 帮忙删除runs文件
注:该代码只需图像文件与对应xml文件,即可按照比列转换train、val、test数据。
三、一键训练/预测的sh内容
1、训练sh文件(train.sh)内容
训练文件为sh文件,只需通过以下命令,实现训练。
sh train.sh
该文件包含虚拟环境切换与自动调用模型训练,其详情如下:
# train.sh
train_weight=/home/ubuntu/Project/tj/auto_project/RTDETR/model_rtdetr/rtdetr-l.pt
echo -e "\n"train time $(date "+%Y-%m-%d")"\n"
# 更换虚拟环境
__conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
eval "$__conda_setup"
else
if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then
. "/home/ubuntu/miniconda3/etc/profile.d/conda.sh"
else
export PATH="/home/ubuntu/miniconda3/bin:$PATH"
fi
fi
unset __conda_setup
conda activate yolov8
cur_dir=$(cd `dirname $0`;pwd) # 获得当前路径
echo -e "\ncur_dir:"${cur_dir}"\n"
yaml_dir=$cur_dir/coco128_auto.yaml
echo -e "\nyaml_dir:"${yaml_dir}"\n"
#save_dir=$cur_dir/runs/train
#echo -e "\nsave_dir:"$save_dir"\n"
#
#
#if [ -d ${save_dir} ];then
# echo "save_dir 文件存在"
# else
# echo "save_dir文件不存在-->创建文件"
# mkdir -p $save_dir
#fi
cd ${cur_dir}
ls
echo -e "\n\n\n\t\t\t start train ... \n\n\n"
# xml数据转txt数据格式
python auto_tools.py
yolo train model=$train_weight data=$yaml_dir epochs=300 imgsz=640 batch=24 amp=False name=train/exp
2、train.sh内容说明
1、开头有一个重要预训练权重路径,确定使用rtdetr哪个模型,默认为l模型
train_weight=/home/oem/Project/tj/auto_project/RTDETR/model_rtdetr/rtdetr-l.pt
2、最后一句模型运行命令,默认参数命令如下:
yolo train model=
t
r
a
i
n
w
e
i
g
h
t
d
a
t
a
=
train_weight data=
trainw?eightdata=yaml_dir epochs=300 imgsz=640 batch=12 amp=False name=train/exp
3、添加参数
显卡选择参数device,添加 device=0,1或device=0等形式
3、预测sh文件(detect.sh)介绍
预测文件为sh文件,只需通过以下命令,实现训练。
sh detect.sh
该文件包含虚拟环境切换与自动调用模型预测,其详情如下:
# detect.sh
echo -e "\n"detect time $(date "+%Y-%m-%d")"\n"
# 更换虚拟环境
__conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
eval "$__conda_setup"
else
if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then
. "/home/ubuntu/miniconda3/etc/profile.d/conda.sh"
else
export PATH="/home/ubuntu/miniconda3/bin:$PATH"
fi
fi
unset __conda_setup
conda activate yolov8
cur_dir=$(cd `dirname $0`;pwd) # 获得当前路径
echo -e "\ncur_dir:"${cur_dir}"\n"
yaml_dir=$cur_dir/coco128_auto.yaml
echo -e "\nyaml_dir:"${yaml_dir}"\n"
save_dir=$cur_dir/runs/detect
echo -e "\nsave_dir:"$save_dir"\n"
if [ -d ${save_dir} ];then
echo "save_dir 文件存在"
else
echo "save_dir文件不存在-->创建文件"
mkdir -p $save_dir
fi
cd ${cur_dir}
ls
echo -e "\n\n\n\t\t\t start detect ... \n\n\n"
python predect.py --conf_thres 0.25
4、detect.sh内容说明
1、最后一句模型运行命令,默认参数命令如下:
python predect.py --conf_thres 0.25
2、添加权重与图片保存路径,如下格式
–weights /home/ubuntu/runs/detect/train/exp/weights/best.pt
–save_dir /home/ubuntu/runs/detect/predect/exp
四、训练、预测运行结果显示
1、训练效果展示
2、预测效果展示
总结
本文一个目的,傻瓜式训练与预测,通过sh脚本实现3个任务,
①、虚拟环境自动切换
②、数据格式自动转换,输入为图像文件与对应xml文件自动完成rtdetr模型训练与预测数据格式
③、模型自动训练与预测,且只需执行sh train.sh或 sh detect.sh即可实现
整体脚本:点击这里
文件整体格式如下图:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!