MMdetection3.0 训练DETR问题分析

2023-12-13 06:34:43

MMdetection3.0 训练DETR问题分析

针对在MMdetection3.0框架下训练DETR模型,验证集AP值一直为0.000的原因作出如下分析并得出结论。

条件:
1、NWPU-VHR-10数据集:共650张,训练:验证=611:39;
2、MMdetection3.0框架实验分析;
3、DETR原论文提供源代码实验分析;
4、已在代码中完成了数据类别定义(num_classes)等相关配置的修改。

分析:
1、在MMdetection3.0框架下,只是加载backbone的预训练权重,val上AP始终为0.0000.如下图所示:=》loss收敛较慢,val始终为0.0000.
在这里插入图片描述
在这里插入图片描述
2、在MMdetection3.0框架下,直接加载detr的完整预训练权重。如下图所示:=》存在警告(size mismatch for bbox_head.fc_cls.weight: copying a param with shape torch.Size([81, 256]) from checkpoint, the shape in current model is torch.Size([11, 256]).
size mismatch for bbox_head.fc_cls.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([11]).
),但训练测试指标还算正常。

=》警告原因:自定义数据集的类别是10+1,而MMdetection3.0提供的是coco数据集与训练权重80+1.
=》因此,需要修改预训练模型的全连接层输出(见下述第4点)。

在这里插入图片描述
在这里插入图片描述3、在MMdetection3.0框架下,直接加载修改后的detr的完整预训练权重训练测试结果见下图所示:=》警告消除,一切正常,并且修改证据权重类别后loss下降变快,val指标更好(不能说更好,只能说更正常)
在这里插入图片描述
在这里插入图片描述4、修改模型权重参数脚本
=》代码中的METAINFO不想修改 不修改也行。
=》主要是pretrained_weights[‘state_dict’][‘bbox_head.fc_cls.weight’].resize_(11, 256)
pretrained_weights[‘state_dict’][‘bbox_head.fc_cls.bias’].resize_(11)

import torch
METAINFO = dict(
    CLASSES=(
        'airplane',
        'ship',
        'storage tank',
        'baseball diamond',
        'tennis court',
        'basketball court',
        'ground track field',
        'harbor',
        'bridge',
        'vehicle',
    ),
    PALETTE=[
        (
            120,
            120,
            120,
        ),
        (
            180,
            120,
            120,
        ),
        (
            6,
            230,
            230,
        ),
        (
            80,
            50,
            50,
        ),
        (
            4,
            200,
            3,
        ),
        (
            120,
            120,
            80,
        ),
        (
            140,
            140,
            140,
        ),
        (
            204,
            5,
            255,
        ),
        (
            230,
            230,
            230,
        ),
        (
            4,
            250,
            7,
        ),
    ])

pretrained_weights = torch.load('/home/admin1/pywork/data/weigh/resnet50-0676ba61.pth')
# 11 是指 数据类别 + 1
pretrained_weights['state_dict']['bbox_head.fc_cls.weight'].resize_(11, 256)
pretrained_weights['state_dict']['bbox_head.fc_cls.bias'].resize_(11)
pretrained_weights['meta']['experiment_name'] = 'detr_r50_8xb2-150e_coco_11'
pretrained_weights['meta']['dataset_meta'] = METAINFO
torch.save(pretrained_weights, "detr_r50_8xb2-150e_coco_%d.pth" % num_classes)

5、DETR原论文提供的源代码训练情况跟MMdetection3.0框架下的情况类似,都必须加载预训练模型,否则就是一直0.000000000000000.

总结分析:
1、NWPU-VHR-10数据量太小导致的问题(90%),等待进一步测试。
2、Transformer模型提出来的时候就已经说明很吃数据,所以没有足够的数据直接使用transformer训练往往效果不好,所以数据量不足的情况下,还是加载预训练权重吧。
3、backbone的权重在模型的比例其实很小,主要还是后面的编码、解码器,所以只加载backbone的权重也没什么用。

总之,数据、数据、数据要足够哇

文章来源:https://blog.csdn.net/MZYYZT/article/details/134946324
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。