mmdet 模型代码从 2.x 迁移到 3.x 的记录(以 Mask R-CNN 为例)

2023-12-13 13:26:19

mmdet 2.x 迁移到 3.x


step1 根据官方文档修改 配置文件:

https://mmdetection.readthedocs.io/zh-cn/latest/migration/config_migration.html

step2 修改自定义的dataset模块

# Version shim: bind DATASETS, Compose, CustomDataset and print_log from
# whichever mmdet / mmcv / mmengine layout is installed.
try:
    # mmdet 2.x
    from mmdet.datasets.builder import DATASETS
    from mmdet.datasets.pipelines import Compose
    from mmdet.datasets.custom import CustomDataset
    from mmcv.utils import print_log
except ImportError:
    # mmdet 3.x: taken when the 2.x import paths above are unavailable.
    # Only ImportError (incl. ModuleNotFoundError) triggers the fallback so
    # that unrelated failures are not silently masked.
    from mmdet.registry import DATASETS
    from mmcv.transforms import Compose
    # BaseDataset is aliased so the class definition below works unchanged.
    from mmengine.dataset import BaseDataset as CustomDataset
    from mmengine.logging import print_log
@DATASETS.register_module()
class YourCustomDataset(CustomDataset):

# 此处省略 部分代码
# 
# 将类初始化(__init__)中的参数 img_prefix 改为 data_prefix='', test_mode=True,并增加 *args, **kwargs
# 同时在类初始化 __init__ 中增加下面两行:
self._metainfo = {}
self._metainfo['classes'] = self.CLASSES

	# 增加一个方法
	@classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names



step3 在模型调用时修改调用的代码

# Version shim for the inference code: mmdet 2.x needs collate/scatter,
# mmdet 3.x needs the relocated Compose and bbox2roi.
try:
    # mmdet 2.x
    from mmcv.parallel import collate, scatter
    from mmdet.datasets.pipelines import Compose
except ImportError:
    # mmdet 3.x
    # Bind ``mmcv`` explicitly: the ``from mmcv.parallel import ...`` form
    # above never binds the name, so the diagnostic print could otherwise
    # raise NameError inside this handler.
    import mmcv
    print(mmcv.__version__, dir(mmcv))
    from mmcv.transforms import Compose
    from mmdet.structures.bbox import bbox2roi
    # NOTE: mmcv.parallel (collate, scatter) has been removed in the newer
    # mmcv, so nothing is imported in its place here.

数据处理部分的代码如下:

# Build the test pipeline, run it on the in-memory image, then adapt the
# result to whichever mmdet version is installed before feature extraction.
cfg = self.model.cfg
device = next(self.model.parameters()).device
print("当前预测使用device{}".format(device))
# The first configured pipeline step is swapped for LoadImage().
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
data = dict(img=image)
data = test_pipeline(data)
try:
    # mmdet 2.x
    # ``device`` is a torch.device, so compare its ``type`` rather than the
    # object itself against a string. The RoIPool restriction named in the
    # message applies to CPU inference, so the check runs on that path only.
    if device.type == "cpu":
        for m in self.model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # Both device paths batch and scatter the sample the same way.
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    try:
        img_meta = data["img_meta"][0]  # version 1.x
    except KeyError:
        img_meta = data["img_metas"][0]
except Exception as e:
    # Any failure of the 2.x path (e.g. collate/scatter not being available)
    # is treated as "running on mmdet 3.x"; the broad catch is deliberate
    # because it doubles as the version dispatch.
    print("mmdet的版本是 3.x", e)
    print(data["inputs"].shape)
    from torchvision import transforms as tfs
    # Per-channel mean/std (presumably ImageNet statistics - TODO confirm
    # they match the model's training config).
    normalize = tfs.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    print(type(data["data_samples"]), data["data_samples"].keys())

    image_size = data["data_samples"].img_shape

    # Re-create the network input from the raw image: resize and center-crop
    # to the pipeline's img_shape, convert to a tensor, then normalize.
    img_aug = tfs.Compose([tfs.Resize(image_size), tfs.CenterCrop(image_size), tfs.ToTensor(), normalize])
    data['inputs'] = [data['inputs']]
    from PIL import Image
    data["img"] = [torch.unsqueeze(img_aug(Image.fromarray(image)).to(device), 0)]

    # Wrap the data sample in a list and expose it under the 2.x-style key
    # so the code after this block can stay version-agnostic.
    data['data_samples'] = [data['data_samples']]
    data["img_meta"] = data["data_samples"]
    img_meta = data["img_meta"][0]

img = data["img"][0]
proposals = None
rescale = True
with torch.no_grad():
    x = self.model.extract_feat(img)

将模型调用中的以下代码(mmdet 2.x 写法):



proposal_list = self.model.rpn_head.simple_test_rpn(x, img_meta)
# pdb.set_trace()
det_bboxes, det_labels = self.model.roi_head.simple_test_bboxes(
    x, img_meta, proposal_list, self.model.test_cfg.rcnn, rescale=rescale)  # mmdet version 2.x
det_bboxes = det_bboxes[0]
det_labels = det_labels[0]
if "cuda" in str(self.device):
    det_bboxes = det_bboxes.detach().cpu()
    det_labels = det_labels.detach().cpu()
ori_shape = img_meta[0]['ori_shape']
scale_factor = img_meta[0]['scale_factor']
if det_bboxes.shape[0] == 0:
    return []
else:
    # if det_bboxes is rescaled to the original image size, we need to
    # rescale it back to the testing scale to obtain RoIs.
    _bboxes = (det_bboxes[:, :4] * scale_factor if rescale else det_bboxes)
    mask_rois = bbox2roi([_bboxes])
    try:
        mask_feats = self.model.mask_roi_extractor(x[:len(self.model.mask_roi_extractor.featmap_strides)],
                                                   mask_rois)
    except:
        mask_feats = self.model.roi_head.mask_roi_extractor(
            x[:len(self.model.roi_head.mask_roi_extractor.featmap_strides)], mask_rois.to(self.device))

    scores = det_bboxes.detach().cpu().numpy()[:, -1]
    bboxes = _bboxes.detach().cpu().numpy()[:, :4]
    labels = det_labels.detach().cpu().numpy()

改为(mmdet 3.x 写法):


rpn_results_list = self.model.rpn_head.predict(x, [img_meta], rescale=False)
results_list = self.model.roi_head.predict(x, rpn_results_list, [img_meta])
batch_data_samples = self.model.add_pred_to_datasample([img_meta], results_list)
det_bboxes = batch_data_samples[0].pred_instances.bboxes.detach().cpu()
scores = batch_data_samples[0].pred_instances.scores.detach().cpu().numpy()
if det_bboxes.shape[0] == 0:
    return []

ori_shape = img_meta.ori_shape
scale_factor = img_meta.scale_factor

_bboxes = det_bboxes
mask_rois = bbox2roi([_bboxes])
mask_feats = self.model.roi_head.mask_roi_extractor(x[:len(self.model.roi_head.mask_roi_extractor.featmap_strides)], mask_rois.to(self.device))
det_labels = batch_data_samples[0].pred_instances.labels.detach().cpu()


调用的时候遇到的问题

遇到的问题:
'ConfigDict' object has no attribute 'nms'

'ConfigDict' object has no attribute 'max_per_img'

将配置文件中的

# Test-time settings as written before the migration.
test_cfg=dict(
       # RPN proposal settings; this sub-dict has no `nms` or `max_per_img`
       # entry, only the `nms_thr` / `max_num` style keys.
       rpn=dict(
           nms_across_levels=False,
           nms_pre=1000,
           nms_post=1000,
           max_num=1000,
           nms_thr=0.7,
           min_bbox_size=0),
       # R-CNN (box / mask head) settings.
       rcnn=dict(
           score_thr=0.05,
           nms=dict(type='nms', iou_threshold=0.5),
           max_per_img=100,
           mask_thr_binary=0.5))

改为:

# 在 rpn 中新增 max_per_img 和 nms 两项;其余旧参数(如 nms_across_levels、nms_post、max_num、nms_thr)在 3.x 中可能已不再生效
 # Test-time settings after the migration.
 test_cfg=dict(
        rpn=dict(
            nms_across_levels=False,
            nms_pre=1000,
            nms_post=1000,
            # Added: same value as the older `max_num` key below.
            max_per_img=1000,
            max_num=1000,
            nms_thr=0.7,
            # Added: same threshold as the older `nms_thr` key above,
            # expressed as an nms config dict.
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        # R-CNN (box / mask head) settings, unchanged by the migration.
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5))

在这里插入图片描述

文章来源:https://blog.csdn.net/Magicapprentice/article/details/134966537
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。