mmdet 模型代码从 2.x 迁移到 3.x 的记录（以 Mask R-CNN 为例）
2023-12-13 13:26:19
mmdet 2.x 迁移到 3.x
Step 1：根据官方文档修改配置文件：
https://mmdetection.readthedocs.io/zh-cn/latest/migration/config_migration.html
Step 2：修改自定义的 dataset 模块
# Import the dataset-related names from whichever mmdet major version is
# installed. Only ImportError triggers the fallback, so genuine errors raised
# while importing mmdet 2.x are not silently hidden.
try:
    # mmdet 2.x
    from mmdet.datasets.builder import DATASETS
    from mmdet.datasets.pipelines import Compose
    from mmdet.datasets.custom import CustomDataset
    from mmcv.utils import print_log
except ImportError:
    # mmdet 3.x: DATASETS comes from mmdet.registry, Compose from
    # mmcv.transforms, and the dataset base class / print_log from mmengine.
    from mmdet.registry import DATASETS
    from mmcv.transforms import Compose
    from mmengine.dataset import BaseDataset as CustomDataset
    from mmengine.logging import print_log
@DATASETS.register_module()
class YourCustomDataset(CustomDataset):
    # Part of the original class body is omitted here.
    #
    # Migration notes for __init__:
    #   * replace the `img_prefix` argument with `data_prefix=''` and
    #     `test_mode=True`, and accept `*args, **kwargs`;
    #   * inside __init__ (these statements use `self`, so they cannot live
    #     at class level), add:
    #
    #         self._metainfo = {}
    #         self._metainfo['classes'] = self.CLASSES
    #
    #     NOTE(review): assigning a new dict discards anything previously
    #     stored in `self._metainfo`; place these lines accordingly.

    # Added for mmdet 3.x, whose base dataset class is used here in place of
    # the 2.x CustomDataset.
    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.

        Raises:
            ValueError: If ``classes`` is of any other type.
        """
        if classes is None:
            return cls.CLASSES
        if isinstance(classes, str):
            # Take it as a file path, one class name per line.
            try:
                # mmcv 2.x (used by mmdet 3.x) moved file helpers to mmengine.
                from mmengine.fileio import list_from_file
            except ImportError:
                # mmcv 1.x (used by mmdet 2.x).
                from mmcv import list_from_file
            class_names = list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')
        return class_names
Step 3：修改模型推理时调用的代码
# Inference helpers for both mmdet major versions. Only ImportError triggers
# the fallback so unrelated import failures are still reported.
try:
    # mmdet 2.x
    from mmcv.parallel import collate, scatter
    from mmdet.datasets.pipelines import Compose
    from mmdet.core import bbox2roi
except ImportError:
    # mmdet 3.x: mmcv.parallel (collate / scatter) has been removed.
    from mmcv.transforms import Compose
    from mmdet.structures.bbox import bbox2roi
数据预处理部分的代码（同时兼容 2.x 与 3.x）：
# Build the test pipeline, run it on the in-memory `image`, and extract the
# backbone/neck features `x` plus the per-image meta information `img_meta`.
# (`self`, `image` and `LoadImage` come from the surrounding predictor code.)
cfg = self.model.cfg
device = next(self.model.parameters()).device
print("当前预测使用device{}".format(device))
# Configs migrated in step 1 describe the test pipeline under
# `test_dataloader.dataset`; mmdet 2.x configs keep it under `data.test`.
if cfg.get("test_dataloader") is not None:
    pipeline_cfg = cfg.test_dataloader.dataset.pipeline
else:
    pipeline_cfg = cfg.data.test.pipeline
# The first (file-loading) transform is replaced by LoadImage so that an
# in-memory image can be fed to the pipeline.
test_pipeline = Compose([LoadImage()] + pipeline_cfg[1:])
data = test_pipeline(dict(img=image))

if "inputs" in data:
    # mmdet 3.x: the packed sample holds `inputs` (image tensor) and
    # `data_samples` (its meta information). The model's own data
    # preprocessor batches, normalises and moves it to the model device, so
    # mean/std and channel order follow the config instead of being
    # hard-coded here.
    data["inputs"] = [data["inputs"]]
    data["data_samples"] = [data["data_samples"]]
    data = self.model.data_preprocessor(data, False)
    img = data["inputs"]
    img_meta = data["data_samples"][0]
else:
    # mmdet 2.x uses the key 'img_metas'; mmdet 1.x used 'img_meta'.
    meta_key = "img_meta" if "img_meta" in data else "img_metas"
    data = collate([data], samples_per_gpu=1)
    if device.type == "cpu":
        # The original code's assertion message states RoIPool is not
        # supported for CPU inference, so the check belongs to the CPU branch.
        for m in self.model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'
        # NOTE(review): mirrors mmdet 2.x's inference_detector, which skips
        # scatter() on CPU and only unwraps the collated meta DataContainer;
        # assumes a MultiScaleFlipAug pipeline (list of DataContainers) —
        # confirm against your config.
        data[meta_key] = data[meta_key][0].data
    else:
        data = scatter(data, [device])[0]
    img_meta = data[meta_key][0]
    img = data["img"][0]
rescale = True
with torch.no_grad():
    x = self.model.extract_feat(img)
将模型调用部分原来（mmdet 2.x）的代码：
# mmdet 2.x post-processing (the code being replaced). Uses `x` from
# extract_feat and `img_meta` / `rescale` from the preceding snippet.
# RPN proposals for the image.
proposal_list = self.model.rpn_head.simple_test_rpn(x, img_meta)
# pdb.set_trace()
# Bbox branch of the RoI head; the results are indexed per image below.
det_bboxes, det_labels = self.model.roi_head.simple_test_bboxes(
x, img_meta, proposal_list, self.model.test_cfg.rcnn, rescale=rescale) # mmdet version 2.x
# Keep the results of the single (first) image.
det_bboxes = det_bboxes[0]
det_labels = det_labels[0]
# Move detections to CPU when the predictor runs on CUDA.
if "cuda" in str(self.device):
det_bboxes = det_bboxes.detach().cpu()
det_labels = det_labels.detach().cpu()
# ori_shape is read but not used within this snippet.
ori_shape = img_meta[0]['ori_shape']
scale_factor = img_meta[0]['scale_factor']
# No detections: return an empty result.
if det_bboxes.shape[0] == 0:
return []
else:
# if det_bboxes is rescaled to the original image size, we need to
# rescale it back to the testing scale to obtain RoIs.
# det_bboxes appears to carry the score in its last column (see `scores`
# below), hence the [:, :4] slice.
_bboxes = (det_bboxes[:, :4] * scale_factor if rescale else det_bboxes)
mask_rois = bbox2roi([_bboxes])
# Mask RoI features: first try a detector-level mask_roi_extractor, then
# fall back to the one owned by roi_head.
# NOTE(review): the bare `except:` also swallows errors raised inside the
# first branch itself (e.g. `mask_rois` is not moved to `self.device` there,
# unlike in the fallback), not only a missing attribute.
try:
mask_feats = self.model.mask_roi_extractor(x[:len(self.model.mask_roi_extractor.featmap_strides)],
mask_rois)
except:
mask_feats = self.model.roi_head.mask_roi_extractor(
x[:len(self.model.roi_head.mask_roi_extractor.featmap_strides)], mask_rois.to(self.device))
# Numpy outputs: scores, boxes (at the scale of `_bboxes`) and labels.
scores = det_bboxes.detach().cpu().numpy()[:, -1]
bboxes = _bboxes.detach().cpu().numpy()[:, :4]
labels = det_labels.detach().cpu().numpy()
改为（mmdet 3.x）：
rpn_results_list = self.model.rpn_head.predict(x, [img_meta], rescale=False)
results_list = self.model.roi_head.predict(x, rpn_results_list, [img_meta])
batch_data_samples = self.model.add_pred_to_datasample([img_meta], results_list)
det_bboxes = batch_data_samples[0].pred_instances.bboxes.detach().cpu()
scores = batch_data_samples[0].pred_instances.scores.detach().cpu().numpy()
if det_bboxes.shape[0] == 0:
return []
ori_shape = img_meta.ori_shape
scale_factor = img_meta.scale_factor
_bboxes = det_bboxes
mask_rois = bbox2roi([_bboxes])
mask_feats = self.model.roi_head.mask_roi_extractor(x[:len(self.model.roi_head.mask_roi_extractor.featmap_strides)], mask_rois.to(self.device))
det_labels = batch_data_samples[0].pred_instances.labels.detach().cpu()
调用时遇到的问题：
'ConfigDict' object has no attribute 'nms'
'ConfigDict' object has no attribute 'max_per_img'
原因是 mmdet 3.x 的 RPN 测试配置改为读取 nms 和 max_per_img 字段，而旧配置的 rpn 部分只有 nms_thr 和 max_num。
解决方法：将配置文件中的
# mmdet 2.x style test-time config (before migration). The rpn section only
# has nms_thr / max_num, which mmdet 3.x no longer reads.
test_cfg=dict(
    rpn=dict(
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        nms_across_levels=False,
        min_bbox_size=0,
    ),
    rcnn=dict(
        nms=dict(type='nms', iou_threshold=0.5),
        score_thr=0.05,
        max_per_img=100,
        mask_thr_binary=0.5,
    ))
改为：
# mmdet 3.x test-time config. The legacy RPN keys of the 2.x config
# (nms_across_levels, nms_post, max_num, nms_thr) are no longer read:
# max_num became max_per_img and nms_thr became nms=dict(iou_threshold=...).
# Their absence is what caused the "'ConfigDict' object has no attribute
# 'nms' / 'max_per_img'" errors. The values are unchanged (1000 / 0.7).
test_cfg=dict(
    rpn=dict(
        nms_pre=1000,
        max_per_img=1000,
        nms=dict(type='nms', iou_threshold=0.7),
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100,
        mask_thr_binary=0.5))
文章来源:https://blog.csdn.net/Magicapprentice/article/details/134966537
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!