基于Centernet的船舶识别系统

2023-12-14 19:12:57

1.研究背景与意义

项目参考AAAI Association for the Advancement of Artificial Intelligence

研究背景与意义:

船舶识别是航海领域中的一个重要问题,对于海上交通管理、海洋资源开发、海上安全等方面具有重要意义。传统的船舶识别方法主要依赖于人工观察和船舶特征提取,但这种方法存在着识别效率低、识别准确性差、对人力资源的依赖性强等问题。因此,研究基于Centernet的船舶识别系统具有重要的现实意义和应用价值。

首先,基于Centernet的船舶识别系统可以提高船舶识别的准确性和效率。Centernet是一种基于中心点的目标检测方法,通过预测目标的中心点位置和目标的边界框,可以实现对目标的准确定位和识别。相比于传统的目标检测方法,Centernet具有更高的准确性和更快的检测速度。在船舶识别中,Centernet可以有效地识别出船舶的位置和边界框,提高识别的准确性和效率。

其次,基于Centernet的船舶识别系统可以减少对人力资源的依赖性。传统的船舶识别方法需要依赖于人工观察和船舶特征提取,需要大量的人力资源进行船舶的识别和分类。而基于Centernet的船舶识别系统可以自动地对船舶进行识别和分类,减少了对人力资源的依赖性,提高了工作效率。

此外,基于Centernet的船舶识别系统还可以应用于海上交通管理、海洋资源开发和海上安全等方面。船舶识别是海上交通管理的重要环节,可以帮助监测和管理海上船舶的数量、位置和航行状态,提高海上交通的安全性和效率。在海洋资源开发中,船舶识别可以帮助监测和管理海上资源的开发情况,保护海洋生态环境,促进可持续发展。在海上安全方面,船舶识别可以帮助监测和管理海上船舶的安全状况,及时发现和处理潜在的安全隐患,保障海上交通的安全性。

综上所述,基于Centernet的船舶识别系统具有重要的现实意义和应用价值。它可以提高船舶识别的准确性和效率,减少对人力资源的依赖性,应用于海上交通管理、海洋资源开发和海上安全等方面。因此,对于船舶识别技术的研究和应用具有重要的意义和价值。

2.图片演示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.视频演示

基于Centernet的船舶识别系统_哔哩哔哩_bilibili

4.数据集的采集&标注和整理

图片的收集

首先,我们需要收集所需的图片。这可以通过不同的方式来实现,例如使用现有的公开数据集CBYGDatasets。

在这里插入图片描述

labelImg是一个图形化的图像注释工具,支持VOC和YOLO格式。以下是使用labelImg将图片标注为VOC格式的步骤:

(1)下载并安装labelImg。
(2)打开labelImg并选择“Open Dir”来选择你的图片目录。
(3)为你的目标对象设置标签名称。
(4)在图片上绘制矩形框,选择对应的标签。
(5)保存标注信息,这将在图片目录下生成一个与图片同名的XML文件。
(6)重复此过程,直到所有的图片都标注完毕。

由于YOLO使用的是txt格式的标注,我们需要将VOC格式转换为Centernet格式。可以使用各种转换工具或脚本来实现。

下面是一个简单的方法是使用Python脚本,该脚本读取XML文件,然后将其转换为Centernet所需的txt格式。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import xml.etree.ElementTree as ET
import os

classes = []  # 初始化为空列表

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

def convert(size, box):
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)

def convert_annotation(image_id):
    in_file = open('./label_xml\%s.xml' % (image_id), encoding='UTF-8')
    out_file = open('./label_txt\%s.txt' % (image_id), 'w')  # 生成txt格式文件
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        cls = obj.find('name').text
        if cls not in classes:
            classes.append(cls)  # 如果类别不存在,添加到classes列表中
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
             float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

xml_path = os.path.join(CURRENT_DIR, './label_xml/')

# xml list
img_xmls = os.listdir(xml_path)
for img_xml in img_xmls:
    label_name = img_xml.split('.')[0]
    print(label_name)
    convert_annotation(label_name)

print("Classes:")  # 打印最终的classes列表
print(classes)  # 打印最终的classes列表

整理数据文件夹结构

我们需要将数据集整理为以下结构:

-----data
   |-----train
   |   |-----images
   |   |-----labels
   |
   |-----valid
   |   |-----images
   |   |-----labels
   |
   |-----test
       |-----images
       |-----labels

确保以下几点:

所有的训练图片都位于data/train/images目录下,相应的标注文件位于data/train/labels目录下。
所有的验证图片都位于data/valid/images目录下,相应的标注文件位于data/valid/labels目录下。
所有的测试图片都位于data/test/images目录下,相应的标注文件位于data/test/labels目录下。
这样的结构使得数据的管理和模型的训练、验证和测试变得非常方便。

模型训练
 Epoch   gpu_mem       box       obj       cls    labels  img_size
 1/200     20.8G   0.01576   0.01955  0.007536        22      1280: 100%|██████████| 849/849 [14:42<00:00,  1.04s/it]
           Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 213/213 [01:14<00:00,  2.87it/s]
             all       3395      17314      0.994      0.957      0.0957      0.0843

 Epoch   gpu_mem       box       obj       cls    labels  img_size
 2/200     20.8G   0.01578   0.01923  0.007006        22      1280: 100%|██████████| 849/849 [14:44<00:00,  1.04s/it]
           Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 213/213 [01:12<00:00,  2.95it/s]
             all       3395      17314      0.996      0.956      0.0957      0.0845

 Epoch   gpu_mem       box       obj       cls    labels  img_size
 3/200     20.8G   0.01561    0.0191  0.006895        27      1280: 100%|██████████| 849/849 [10:56<00:00,  1.29it/s]
           Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|███████   | 187/213 [00:52<00:00,  4.04it/s]
             all       3395      17314      0.996      0.957      0.0957      0.0845

5.核心代码讲解

5.1 centernet.py


class CenterNet(object):
    _defaults = {
        "model_path"        : 'logs/best_epoch_weights.pth',
        "classes_path"      : 'model_data/voc_classes.txt',
        "backbone"          : 'resnet50',
        "input_shape"       : [256, 256],
        "confidence"        : 0.3,
        "nms_iou"           : 0.3,
        "nms"               : True,
        "letterbox_image"   : False,
        "cuda"              : True
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            self._defaults[name] = value 
            
        self.class_names, self.num_classes  = get_classes(self.classes_path)

        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        self.generate()
        
        show_config(**self._defaults)

    def generate(self, onnx=False):
        assert self.backbone in ['resnet50', 'hourglass']
        if self.backbone == "resnet50":
            self.net = CenterNet_Resnet50(num_classes=self.num_classes, pretrained=False)
        else:
            self.net = CenterNet_HourglassNet({'hm': self.num_classes, 'wh': 2, 'reg':2})

        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net    = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = torch.nn.DataParallel(self.net)
                self.net = self.net.cuda()

    def detect_image(self, image, crop = False, count = False):
        image_shape = np.array(np.shape(image)[0:2])
        image       = cvtColor(image)
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(np.asarray(image_data)).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            outputs = self.net(images)
            if self.backbone == 'hourglass':
                outputs = [outputs[-1]["hm"].sigmoid(), outputs[-1]["wh"], outputs[-1]["reg"]]
            outputs = decode_bbox(outputs[0], outputs[1], outputs[2], self.confidence, self.cuda)

            results = postprocess(outputs, self.nms, image_shape, self.input_shape, self.letterbox_image, self.nms_iou)
            
            if results[0] is None:
                return image

            top_label   = np.array(results[0][:, 5], dtype = 'int32')
            top_conf    = results[0][:, 4]
            top_boxes   = results[0][:, :4]

        font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
        thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.input_shape[0], 1)
        
        if count:
            print("top_label:", top_label)
            classes_nums    = np.zeros([self.num_classes])
            for i in range(self.num_classes):
                num = np.sum(top_label == i)
                if num > 0:
                    print(self.class_names[i], " : ", num)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)
        
        if crop:
            for i, c in list(enumerate(top_label)):
                top, left, bottom, right = top_boxes[i]
                top     = max(0, np.floor(top).astype('int32'))
                left    = max(0, np.floor(left).astype('int32'))
                bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
                right   = min(image.size[0], np.floor(right).astype('int32'))
                
                dir_save_path = "img_crop"
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                crop_image = image.crop([left, top, right, bottom])
                crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
                print("save crop_" + str(i) + ".png to " + dir_save_path)
        
        label_list = []
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = top_conf[i]

            top, left, bottom, right = box

            top     = max(0, np.floor(top).astype('int32'))
            left    = max(0, np.floor(left).astype('int32'))
            bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
            right   = min(image.size[0], np.floor(right).astype('int32'))
            label_list.append(predicted_class)
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
           

该程序文件是一个用于目标检测的Centernet模型的实现。Centernet模型是一种基于中心点的目标检测算法,可以用于检测图像中的多个目标。

该程序文件包含以下功能:

  1. 初始化Centernet模型,包括加载模型权重和类别信息。
  2. 对输入的图像进行预处理,包括转换为RGB图像、调整大小和归一化。
  3. 使用Centernet模型对图像进行预测,得到目标的位置和类别信息。
  4. 根据预测结果进行后处理,包括非极大抑制和解码。
  5. 绘制预测结果的边界框和类别标签。

该程序文件还提供了一些可配置的参数,包括模型路径、类别文件路径、主干网络类型、输入图像大小、置信度阈值、非极大抑制的IOU阈值等。

在使用该程序文件进行目标检测时,需要根据实际情况修改模型路径、类别文件路径和主干网络类型等参数。

5.2 get_map.py


class MapCalculator:
    def __init__(self, map_mode, classes_path, MINOVERLAP, confidence, nms_iou, score_threhold, map_vis, VOCdevkit_path, map_out_path):
        self.map_mode = map_mode
        self.classes_path = classes_path
        self.MINOVERLAP = MINOVERLAP
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.score_threhold = score_threhold
        self.map_vis = map_vis
        self.VOCdevkit_path = VOCdevkit_path
        self.map_out_path = map_out_path

    def load_model(self):
        print("Load model.")
        self.centernet = CenterNet(confidence=self.confidence, nms_iou=self.nms_iou)
        print("Load model done.")

    def get_predict_result(self):
        print("Get predict result.")
        for image_id in tqdm(self.image_ids):
            image_path = os.path.join(self.VOCdevkit_path, "VOC2007/JPEGImages/" + image_id + ".jpg")
            image = Image.open(image_path)
            if self.map_vis:
                image.save(os.path.join(self.map_out_path, "images-optional/" + image_id + ".jpg"))
            self.centernet.get_map_txt(image_id, image, self.class_names, self.map_out_path)
        print("Get predict result done.")

    def get_ground_truth_result(self):
        print("Get ground truth result.")
        for image_id in tqdm(self.image_ids):
            with open(os.path.join(self.map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
                root = ET.parse(os.path.join(self.VOCdevkit_path, "VOC2007/Annotations/" + image_id + ".xml")).getroot()
                for obj in root.findall('object'):
                    difficult_flag = False
                    if obj.find('difficult') != None:
                        difficult = obj.find('difficult').text
                        if int(difficult) == 1:
                            difficult_flag = True
                    obj_name = obj.find('name').text
                    if obj_name not in self.class_names:
                        continue
                    bndbox = obj.find('bndbox')
                    left = bndbox.find('xmin').text
                    top = bndbox.find('ymin').text
                    right = bndbox.find('xmax').text
                    bottom = bndbox.find('ymax').text

                    if difficult_flag:
                        new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
                    else:
                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
        print("Get ground truth result done.")

    def get_map(self):
        print("Get map.")
        get_map(self.MINOVERLAP, True, score_threhold=self.score_threhold, path=self.map_out_path)
        print("Get map done.")

    def get_coco_map(self):
        print("Get map.")
        get_coco_map(class_names=self.class_names, path=self.map_out_path)
        print("Get map done.")

    def calculate_map(self):
        self.image_ids = open(os.path.join(self.VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()

        if not os.path.exists(self.map_out_path):
            os.makedirs(self.map_out_path)
        if not os.path.exists(os.path.join(self.map_out_path, 'ground-truth')):
            os.makedirs(os.path.join(self.map_out_path, 'ground-truth'))
        if not os.path.exists(os.path.join(self.map_out_path, 'detection-results')):
            os.makedirs(os.path.join(self.map_out_path, 'detection-results'))
        if not os.path.exists(os.path.join(self.map_out_path, 'images-optional')):
            os.makedirs(os.path.join(self.map_out_path, 'images-optional'))

        self.class_names, _ = get_classes(self.classes_path)

        if self.map_mode == 0 or self.map_mode == 1:
            self.load_model()
            self.get_predict_result()

        if self.map_mode == 0 or self.map_mode == 2:
            self.get_ground_truth_result()

        if self.map_mode == 0 or self.map_mode == 3:
            self.get_map()

        if self.map_mode == 4:
            self.get_coco_map()


该程序文件名为get_map.py,主要功能是计算目标检测模型的mAP(mean Average Precision)。

程序的主要流程如下:

  1. 导入所需的库和模块,包括os、xml.etree.ElementTree、PIL、tqdm等。
  2. 导入自定义的CenterNet类和一些辅助函数。
  3. 根据map_mode的值确定计算内容,map_mode为0表示计算整个mAP计算流程,包括获得预测结果、获得真实框、计算VOC_map;map_mode为1表示仅获得预测结果;map_mode为2表示仅获得真实框;map_mode为3表示仅计算VOC_map;map_mode为4表示利用COCO工具箱计算当前数据集的0.50:0.95map。
  4. 设置一些参数,如classes_path(指定需要测量VOC_map的类别)、MINOVERLAP(指定想要获得的mAP0.x)、confidence(预测时使用的置信度阈值)、nms_iou(预测时使用的非极大抑制值的大小)等。
  5. 根据VOCdevkit_path指向的VOC数据集文件夹,读取测试集图像的image_ids。
  6. 创建输出文件夹map_out及其子文件夹。
  7. 根据map_mode的值,执行相应的操作:
    • 如果map_mode为0或1,加载模型,获取预测结果,并将结果保存为txt文件。
    • 如果map_mode为0或2,获取真实框,并将结果保存为txt文件。
    • 如果map_mode为0或3,计算VOC_map,并将结果保存为txt文件。
    • 如果map_mode为4,利用COCO工具箱计算当前数据集的0.50:0.95map。
  8. 打印相应的提示信息,表示计算完成。

该程序的主要功能是计算目标检测模型的mAP,并根据map_mode的值决定计算的内容。

5.3 predict.py


class ObjectDetection:
    def __init__(self):
        self.centernet = CenterNet()

    def predict_image(self, image_path, crop=False, count=False):
        try:
            image = Image.open(image_path)
        except:
            print('Open Error! Try again!')
            return None
        else:
            r_image = self.centernet.detect_image(image, crop=crop, count=count)
            return r_image

    def detect_video(self, video_path, video_save_path="", video_fps=25.0):
        capture = cv2.VideoCapture(video_path)
        if video_save_path != "":
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        ref, frame = capture.read()
        if not ref:
            raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")

        fps = 0.0
        while (True):
            t1 = time.time()
            # 读取某一帧
            ref, frame = capture.read()
            if not ref:
                break
            # 格式转变,BGRtoRGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # 转变成Image
            frame = Image.fromarray(np.uint8(frame))
            # 进行检测
            frame = np.array(self.centernet.detect_image(frame))
            # RGBtoBGR满足opencv显示格式
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            fps = (fps + (1. / (time.time() - t1))) / 2
            print("fps= %.2f" % (fps))
            frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow("video", frame)
            c = cv2.waitKey(1) & 0xff
            if video_save_path != "":
                out.write(frame)

            if c == 27:
                capture.release()
                break

        print("Video Detection Done!")
        capture.release()
        if video_save_path != "":
            print("Save processed video to the path :" + video_save_path)
            out.release()
        cv2.destroyAllWindows()

    def measure_fps(self, fps_image_path, test_interval=100):
        img = Image.open(fps_image_path)
        tact_time = self.centernet.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')

    def detect_directory(self, dir_origin_path, dir_save_path):
        import os
        from tqdm import tqdm

        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):
            if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path = os.path.join(dir_origin_path, img_name)
                image = Image.open(image_path)
                r_image = self.centernet.detect_image(image)
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)

    def generate_heatmap(self, image_path, heatmap_save_path):
        try:
            image = Image.open(image_path)
        except:
            print('Open Error! Try again!')
            return None
        else:
            self.centernet.detect_heatmap(image, heatmap_save_path)

    def export_onnx(self, simplify=True, onnx_save_path="model_data/models.onnx"):
        self.centernet.convert_to_onnx(simplify, onnx_save_path)


该程序文件名为predict.py,主要实现了单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能。程序通过指定mode参数来选择不同的功能模式。

具体功能模式如下:

  • “predict”:单张图片预测模式。用户可以输入图片文件名,程序会读取并进行目标检测,并显示检测结果。
  • “video”:视频检测模式。用户可以指定视频路径,程序会读取视频并进行目标检测,可以选择保存检测结果视频。
  • “fps”:测试FPS模式。程序会读取指定的测试图片,计算模型的FPS(每秒处理的帧数)。
  • “dir_predict”:目录遍历检测模式。程序会遍历指定目录下的所有图片文件,并进行目标检测,可以选择保存检测结果图片。
  • “heatmap”:预测结果热力图可视化模式。用户可以输入图片文件名,程序会读取并进行目标检测,并生成预测结果的热力图。
  • “export_onnx”:导出模型为ONNX格式模式。程序会将模型导出为ONNX格式文件。

在程序中,用户可以根据需要修改各个模式下的参数,如是否进行目标截取、目标计数、视频路径、视频保存路径、视频帧率、测试图片路径、目录遍历路径、热力图保存路径等。

在"predict"模式下,用户可以根据需要对预测过程进行修改,如保存图片、截取目标、在预测图上写额外的字等。

在"video"模式下,用户可以选择是否保存检测结果视频,并可以通过按下ESC键退出程序。

在"dir_predict"模式下,程序会遍历指定目录下的所有图片文件,并将检测结果保存到指定的保存路径。

在"heatmap"模式下,用户可以根据需要生成预测结果的热力图。

在"export_onnx"模式下,程序会将模型导出为ONNX格式文件,可以选择是否进行Simplify操作,并指定导出的ONNX文件保存路径。

如果mode参数不是上述模式之一,则会抛出错误提示用户指定正确的模式。

5.4 summary.py


class NetworkInfo:
    def __init__(self, input_shape, num_classes):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CenterNet_Resnet50().to(self.device)
    
    def get_summary(self):
        summary(self.model, (3, self.input_shape[0], self.input_shape[1]))
    
    def get_flops_params(self):
        dummy_input = torch.randn(1, 3, self.input_shape[0], self.input_shape[1]).to(self.device)
        flops, params = profile(self.model.to(self.device), (dummy_input, ), verbose=False)
        flops = flops * 2
        flops, params = clever_format([flops, params], "%.3f")
        return flops, params


这个程序文件名为summary.py,主要用于查看网络参数。程序首先导入了必要的库,包括torch、thop和torchsummary。然后定义了输入形状和类别数。接下来,程序将模型加载到设备上,并使用torchsummary库打印出模型的概要信息。然后,程序创建了一个虚拟输入,并使用thop库计算模型的FLOPs和参数数量。最后,程序将FLOPs和参数数量格式化为可读的形式,并打印出来。

5.5 train.py


class CenterNetTrainer:
    def __init__(self, cuda=True, distributed=False, sync_bn=False, fp16=False, classes_path='model_data/voc_classes.txt', model_path='model_data/centernet_resnet50_voc.pth', input_shape=[256, 256], backbone="resnet50", pretrained=False):
        self.cuda = cuda
        self.distributed = distributed
        self.sync_bn = sync_bn
        self.fp16 = fp16
        self.classes_path = classes_path
        self.model_path = model_path
        self.input_shape = input_shape
        self.backbone = backbone
        self.pretrained = pretrained

    def train(self):
        # ... 
        pass


该程序文件是用于训练目标检测模型的。文件中包含了训练的一些注意事项和参数设置。

程序首先导入了所需的库和模块,包括numpy、torch、torchvision等。然后定义了一些训练所需的参数,如是否使用GPU、是否使用分布式训练、是否使用混合精度训练等。

接下来,程序定义了一些辅助函数和类,包括数据集加载、模型构建、优化器设置等。这些函数和类的具体实现可以在其他文件中找到。

最后,程序的主函数中进行了一些参数的设置,如是否使用CUDA、是否使用分布式训练、是否使用sync_bn等。然后加载了数据集的类别信息和预训练权重文件。

最后,程序调用了fit_one_epoch函数进行模型的训练。fit_one_epoch函数会根据参数设置进行模型的训练,并保存训练过程中的损失值和权重文件。

总的来说,该程序文件是用于训练目标检测模型的,其中包含了训练的一些注意事项和参数设置,以及模型的训练过程。

6.系统整体结构

整体功能和构架概述:
该项目是一个基于CenterNet的船舶识别系统,使用PyQt5库创建了一个图形用户界面,用户可以通过界面选择图片文件并进行预处理和检测操作。程序中包含了深度学习模型的定义和相关的算法实现,以及用于数据加载、训练、评估和可视化的工具函数。

以下是每个文件的功能概述:

文件路径功能概述
centernet.py实现了Centernet模型的目标检测功能
get_map.py计算目标检测模型的mAP(mean Average Precision)
predict.py实现了单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
summary.py查看网络参数的概要信息
train.py用于训练目标检测模型
ui.py创建了一个基于PyQt5的图形用户界面,用于船舶识别系统的操作
vision_for_centernet.py提供了一些图像处理和可视化的函数
voc_annotation.py用于生成VOC格式的标注文件
nets/centernet.py实现了Centernet模型的网络结构
nets/centernet_training.py实现了Centernet模型的训练过程
nets/hourglass.py实现了Hourglass网络结构
nets/resnet50.py实现了ResNet50网络结构
utils/callbacks.py提供了一些训练过程中的回调函数
utils/dataloader.py提供了数据加载和预处理的函数
utils/utils.py提供了一些通用的工具函数
utils/utils_bbox.py提供了一些处理边界框的工具函数
utils/utils_fit.py提供了一些模型训练和优化的工具函数
utils/utils_map.py提供了一些计算mAP的工具函数
utils/init.py初始化文件,标识该目录为Python包

以上是每个文件的大致功能概述,具体实现和细节可以在各个文件中查看。

7.目前遥感船舶检测的现状

遥感图像中的船舶检测在各种应用中起着至关重要的作用,近年来受到越来越多的关注。然而,现有的多角度船舶检测方法通常是在一组预定义的旋转锚框上发展起来的。这些预定义的框不仅导致不准确的角度预测,而且还引入额外的超参数和高计算成本。此外,现有的检测方法还没有充分利用船舶尺寸的先验知识,影响了检测精度的提高。

基于高分辨率光学遥感图像的船舶检测在非法走私、港口管理、目标侦察等领域有着广泛的应用。最近几十年来,船舶检测受到了越来越多的关注,并得到了广泛的研究。然而,由于遥感图像中船舶的任意方位、密集停放场景和复杂背景,使得船舶检测成为一项极具挑战性的任务。为了处理多方向性问题,现有的方法通常使用一系列预定义的锚点,这有以下缺点。

1.角度回归不准确。
图1(a)-(d)展示了任意定向船舶的四种不同表示。由于遥感图像中的船舶通常是条带状的,因此IOU的分数对边界框的角度非常敏感。如图1(e)所示,GT框是宽高比为10:1的船舶边界框。红色旋转框是通过将GT框旋转5°的小角度来生成的。可以观察到,如此小的角度变化将这两个盒子之间的IoU减少到0.63。因此,基于锚点的检测器通过IoU分数定义正锚点和负锚点,通常会遇到不同锚点之间的不平衡问题,从而导致检测性能退化。此外,船舶的角度是一个周期函数,在边界(0°或180°处)是不连续的,如图1(f)所示。这种不连续性还会导致性能下降。
在这里插入图片描述

2.超参数过多和计算成本高
现有方法通常使用定向边界框作为锚来处理旋转的对象,从而引入过多的超参数,如框大小、纵横比和方向角。注意,这些超参数必须手动调整以适应新的场景,这限制了这些方法的泛化能力。预定义的基于锚的方法通常需要大量的锚框。例如,在RRPN中,在旋转的锚框中使用了6个不同的方向,并且在其特征图上的每个像素上总共有24个锚。在计算IoU分数和执行非极大值抑制(NMS)算法时,大量的锚盒会引入过多的计算开销。

3.未充分利用船舶的先验信息。
以往的船舶检测器大多直接采用与遥感和场景文本检测相同的旋转检测算法。然而,遥感图像中的舰船有其独特的特点。通过对遥感图像中某类船舶的地面样本距离(GSD)进行归一化处理,发现该类船舶在遥感图像中具有相对固定的尺寸范围。船的大小和船头的位置是探测的重要线索。然而,现有的方法没有充分利用这些先验信息。

基于自然场景中的无锚点探测器CenterNet,本文提出了一种one-stage、无锚点、无NMS的遥感图像中任意方向船舶检测方法。我们把船描述成旋转的盒子,上面有一个代表方向的头部点。具体地说,方向不变特征映射首先由方向不变模型生成。然后,选择中心特征图的峰值作为中心点,在每个中心点的相应特征映射上回归偏移、对象大小和头部位置。最后,利用目标大小调整分类分数。

在这里插入图片描述

(首先利用完全卷积骨干网和方向不变模型(OIM)生成特征映射。然后选取中心点特征图的峰值作为中心点。然后,在每个中心点位置的相应特征地图上回归中心点偏移、对象大小和头部回归位置。通过在头部特征图上提取置信度大于0.1的峰值来收集潜在的头部点。最终的头部位置是通过将每个回归位置分配给其最近的潜在头部点,然后添加头部偏移来获得的。)
在这里插入图片描述

8.Centernet简介

CenterNet算法简介

CenterNet是一个基于Anchor-free的目标检测算法,该算法是在CornerNet算法的基础上改进而来的。与单阶段目标检测算法yolov3相比,该算法在保证速度的前提下,精度提升了4个百分点。与其它的单阶段或者双阶段目标检测算法相比,该算法具有以下的优势:

(1)该算法去除低效复杂的Anchors操作,进一步提升了检测算法性能;
(2)该算法直接在heatmap图上面执行了过滤操作,去除了耗时的NMS后处理操作,进一步提升了整个算法的运行速度;
(3)该算法不仅可以应用到2D目标检测中,经过简单的改变它还可以应用3D目标检测与人体关键点检测等其它的任务中,即具有很好的通用性。

CenterNet网络结构

在这里插入图片描述

上图展示了CenterNet网络的整体结构,整个网络结构比较简单。

(1)最左边表示输入图片。输入图片需要裁减到512*512大小,即长边缩放到512,短边补0,具体的效果如下图所示,由于原图的W>512,因而直接将其缩放为512;由于原图的H<512,因而对其执行补0操作;

(2)中间表示基准网络,论文中尝试了Hourglass、ResNet与DLA3种网络架构,各个网络架构的精度及帧率为:Resnet-18 with up-convolutional layers:28.1% coco and 142 FPS、DLA-34:37.4% COCOAP and 52 FPS、Hourglass-104:45.1% COCOAP and 1.4 FPS。
在这里插入图片描述

上图展示了3中不同的网络架构,图(a)表示Hourglass网络,该网络是在ECCV2016中的Stacked hourglass networks for human pose estimation论文中提出的一种网络,用来解决人体位姿估计问题,其思路主要通过将多个漏斗形状的网络堆叠起来,从而获得多尺度信息,具体的细节请参考该博客。图(b)表示带有反卷积的ResNet网络,作者在每一个上采样层之前增加了一个3*3的膨胀卷积,即先使用反卷积来改变膨胀卷积的通道个数,然后使用反卷积来对特征映射执行上采样操作。图?表示用于语义分割的DLA34网络;图d表示改变的DLA34网络,该网络在原始的DLA34网络的基础上增加了更多的残差连接,该网络将Dense_Connection与FPN的思路融合起来,前者源于DenseNet,可以用来聚合语义信息,能够提升模型推断是“what”的能力;后者源于聚合空间信息,能够提升模型推断在“where”的能力,具体的细节如下图所示。
在这里插入图片描述

(3)最右边表示预测模块,该模块包含3个分支,具体包括中心点heatmap图分支、中心点offset分支、目标大小分支。heatmap图分支包含C个通道,每一个通道包含一个类别,heatmap中白色的亮区域表示目标的中心 点位置;中心点offset分支用来弥补将池化后的低heatmap上的点映射到原图中所带来的像素误差;目标大小分支用来预测目标矩形框的w与h偏差值。

训练阶段Heatmap生成

CenterNet将目标检测问题转换成中心点预测问题,即用目标的中心点来表示该目标,并通过预测目标中心点的偏移量与宽高来获取目标的矩形框。Heatmap表示分类信息,每一个类别将会产生一个单独的Heatmap图。对于每张Heatmap图而言,当某个坐标处包含目标的中心点时,则会在该目标处产生一个关键点,我们利用高斯圆来表示整个关键点,下图展示了具体的细节。
在这里插入图片描述

生成Heatmap图的具体步骤如下所示:

步骤1-将输入的图片缩放成512512大小,对该图像执行R=4的下采样操作之后,获得一个128128大小的Heatmap图;

步骤2-将输入图片中的Box缩放到128*128大小的Heatmap图上面,计算该Box的中心点坐标,并执行向下取整操作,并将其定义为point;

步骤3-根据目标Box大小来计算高斯圆的半径R;
??关于高斯圆的半径确定,主要还是依赖于目标box的宽高, 实际情况下通常会取IOU=0.7,即下图中的overlap=0.7作为临界值,然后分别计算出三种情况的半径,取最小值作为高斯核的半径R,具体的实现细节如下图所示:
(1)情况1-预测框pred_bbox包含gt_bbox框,对应于下图中的第1种情况,将整个IoU公式展开之后,成为一个二元一次方程的求解问题。
(2)情况2-gt_bbox包含预测框pred_bbox框,对应于下图中的第2种情况,将整个IoU公式展开之后,成为一个二元一次方程的求解问题。
(3)情况3-gt_bbox与预测框pred_bbox框相互重叠,对应于下图中的第3种情况,将整个IoU公式展开之后,成为一个二元一次方程的求解问题。
在这里插入图片描述

步骤4-在128*128大小的Heatmap图上面,以point为中心点,半径为R计算高斯值,point点处数值最大,随着半径R的增加数值不断减小;

上图展示了一个样例,左边表示经过裁剪之后的512512大小的输入图片,右边表示经过高斯操作之后生成的128128大小的Heatmap图。由于图中包含两只猫,这两只猫属于一个类别,因此在同一个Heatmap图上面生成了两个高斯圆,高斯圆的大小与矩形框的大小有关。

Heatmap上应用高斯核

Heatmap上的关键点之所以采用二维高斯核来表示,是由于对于在目标中心点附近的一些点,其预测出来的pre_box和gt_box的IOU可能会大于0.7,不能直接对这些预测值进行惩罚,需要温和一点,所以采用高斯核。该问题在Corner算法中就已经存在,如下图所示,我们在设置gt_bbox的heatmap时,不仅仅只在中心点的位置设置标签1,图中红色的矩形框表示gt_bbox,但是绿色的矩形框其实也可以很好的包围该目标,即我们在检测的过程中如何获得像绿色框这样的矩形框时,我们也要保存它。通俗一点来讲,只要预测的corner点在中心点的某一个半径r内,而且该矩形框与gt_bbox之间的IoU大于0.7时,我们将这些点处的值设置为一个高斯分布的数值,而不是数值0。

在这里插入图片描述

9.系统整合

下图完整源码&数据集&环境部署视频教程&自定义UI界面

在这里插入图片描述

参考博客《基于Centernet的船舶识别系统》

10.参考文献


[1]刘鑫,黄进,杨涛,等.改进CenterNet的无人机小目标捕获检测方法[J].计算机工程与应用.2022,58(14).DOI:10.3778/j.issn.1002-8331.2111-0033 .

[2]魏玮,杨茹,朱叶.改进CenterNet的遥感图像目标检测[J].计算机工程与应用.2021,(6).DOI:10.3778/j.issn.1002-8331.2007-0052 .

[3]方钧婷,谭晓阳.注意力级联网络的金属表面缺陷检测算法[J].计算机科学与探索.2021,(7).DOI:10.3778/j.issn.1673-9418.2007005 .

[4]黄致君,桑庆兵.改进R-FCN的船舶识别方法[J].计算机科学与探索.2020,(6).DOI:10.3778/j.issn.1673-9418.1904061 .

[5]Girshick, R.,Donahue, J.,Darrell, T.,等.Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation[C].

[6]Ross Girshick.Fast R-CNN[C].

文章来源:https://blog.csdn.net/xuehaisj/article/details/134861571
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。