YOLOv5 Knowledge Distillation

2023-12-25 19:33:35

Reference code: https://github.com/Adlik/yolov5
https://cloud.tencent.com/developer/article/2160509
This post covers distillation between YOLOv5 models that share the same architecture.

Configuration parameters

parser.add_argument('--t_weights', type=str, default='./weights/yolov5s.pt',
                    help='initial teacher model weights path')
parser.add_argument('--t_cfg', type=str, default='models/yolov5s.yaml', help='teacher model.yaml path')
parser.add_argument('--d_output', action='store_true', default=False,
                    help='if true, only distill outputs')
parser.add_argument('--d_feature', action='store_true', default=False,
                    help='if true, distill both feature and output layers')
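
As a minimal sketch of how these flags behave once parsed, the snippet below rebuilds the four arguments in a standalone parser and parses an example command line. The helper name build_parser is hypothetical; in train.py the arguments would simply be added to the existing parser.

```python
import argparse


def build_parser():
    """Standalone parser holding only the distillation-related arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--t_weights', type=str, default='./weights/yolov5s.pt',
                        help='initial teacher model weights path')
    parser.add_argument('--t_cfg', type=str, default='models/yolov5s.yaml',
                        help='teacher model.yaml path')
    parser.add_argument('--d_output', action='store_true', default=False,
                        help='if true, only distill outputs')
    parser.add_argument('--d_feature', action='store_true', default=False,
                        help='if true, distill both feature and output layers')
    return parser


# Equivalent to running: python train.py --d_feature
opt = build_parser().parse_args(['--d_feature'])
print(opt.t_weights, opt.t_cfg, opt.d_output, opt.d_feature)
```

Both switches use store_true, so they stay False unless passed explicitly, and the teacher paths fall back to their defaults.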

Loading the teacher model

Model

check_suffix(weights, '.pt')  # check weights
pretrained = weights.endswith('.pt')
if pretrained:
    with torch_distributed_zero_first(LOCAL_RANK):
        weights = attempt_download(weights)  # download if not found locally
    ckpt = torch.load(weights, map_location=device)  # load checkpoint
    model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
    csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
    csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
    model.load_state_dict(csd, strict=False)  # load
    LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report

    # Load the teacher model here (added for distillation)
    # Teacher model
    t_ckpt = torch.load(t_weights, map_location=device)  # load checkpoint
    t_model = Model(t_cfg or t_ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
    exclude = ['anchor'] if (t_cfg or hyp.get('anchors')) and not resume else []  # exclude keys
    csd = t_ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
    csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude)  # intersect
    t_model.load_state_dict(csd, strict=False)  # load
    LOGGER.info(f'Loaded teacher model {t_cfg}')  # report
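
The excerpt above does not show how the teacher is kept fixed during training. A common practice, given here only as an assumption and not as what the referenced repository does, is to disable gradients on the teacher's parameters and run its forward pass under torch.no_grad(). The helper freeze_teacher and the small stand-in network are hypothetical. Whether the teacher is additionally switched to eval() is not shown in the source.

```python
import torch
import torch.nn as nn


def freeze_teacher(t_model: nn.Module) -> nn.Module:
    """Stop gradient tracking on every teacher parameter (hypothetical helper)."""
    for p in t_model.parameters():
        p.requires_grad_(False)
    return t_model


# Stand-in network; in train.py this would be the t_model loaded above.
teacher = freeze_teacher(nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU()))

imgs = torch.zeros(1, 3, 32, 32)
with torch.no_grad():  # no autograd graph is built for the teacher pass
    t_pred = teacher(imgs)
```

With the teacher frozen this way, only the student's parameters receive gradient updates from the distillation terms.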

Loss function:

    s_loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size

    d_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10)

    if opt.d_feature:
        d_feature_loss = compute_distillation_feature_loss(s_f, t_f, model, f_weight=0.1)
        loss = d_outputs_loss + s_loss + d_feature_loss
    else:
        loss = d_outputs_loss + s_loss
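
In the excerpt above the output distillation term is always added to the student loss, and only the feature term is gated by opt.d_feature. The bodies of compute_distillation_output_loss and compute_distillation_feature_loss are not shown, so the following is only a simplified sketch under the assumption that both terms are mean-squared errors between matching student and teacher tensors. The function names ending in _sketch are hypothetical, and the referenced repository may weight or mask the terms differently.

```python
import torch
import torch.nn.functional as F


def output_distillation_sketch(s_pred, t_pred, d_weight=10.0):
    """Hypothetical output term: MSE between student and teacher head outputs.

    s_pred and t_pred are lists with one tensor per detection scale; shapes
    must match, which holds when both models share the same architecture.
    """
    per_scale = [F.mse_loss(s, t.detach()) for s, t in zip(s_pred, t_pred)]
    return d_weight * torch.stack(per_scale).mean()


def feature_distillation_sketch(s_f, t_f, f_weight=0.1):
    """Hypothetical feature term: MSE between matching intermediate feature maps."""
    per_layer = [F.mse_loss(s, t.detach()) for s, t in zip(s_f, t_f)]
    return f_weight * torch.stack(per_layer).mean()


# Dummy tensors standing in for head outputs: (batch, anchors, h, w, outputs).
shapes = [(2, 3, 8, 8, 85), (2, 3, 4, 4, 85), (2, 3, 2, 2, 85)]
s_pred = [torch.randn(sh, requires_grad=True) for sh in shapes]
t_pred = [torch.randn(sh) for sh in shapes]

d_outputs_loss = output_distillation_sketch(s_pred, t_pred)
d_outputs_loss.backward()  # gradients reach only the student tensors
```

The teacher tensors are detached so that the distillation terms never push gradients into the teacher, and the default weights mirror the d_weight=10 and f_weight=0.1 values used in the calls above.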

Source: https://blog.csdn.net/weixin_41012399/article/details/135205474