RandomCrop for COCO-format datasets

2023-12-13 13:37:38

Changes to transforms.py

Add a RandomCrop class

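import random


class RandomCrop(object):
    """Randomly crop the image together with its bboxes (and masks, if present)."""
    def __init__(self, output_size):
        self.output_size = output_size  # side length of the square crop, in pixels

    def __call__(self, image, target):
        # image is expected to be a (C, H, W) tensor, i.e. ToTensor has already run
        height, width = image.shape[-2:]
        # Cap the crop at the image size so random.randint never gets a negative upper bound
        th = min(self.output_size, height)
        tw = min(self.output_size, width)

        if width == tw and height == th:
            return image, target

        # Top-left corner of the crop window
        x = random.randint(0, width - tw)
        y = random.randint(0, height - th)

        image = image[:, y:y+th, x:x+tw]

        # Shift the boxes into the crop's coordinate frame.
        # Note: boxes are not clipped to the crop window here.
        bbox = target["boxes"]
        bbox[:, [0, 2]] = bbox[:, [0, 2]] - x
        bbox[:, [1, 3]] = bbox[:, [1, 3]] - y
        target["boxes"] = bbox

        # Crop the masks with the same window as the image
        if "masks" in target:
            target["masks"] = target["masks"][:, y:y+th, x:x+tw]

        return image, target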
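
To make the box arithmetic in `__call__` concrete, here is a small worked example in plain Python. The helper `shift_box` and the sample coordinates are hypothetical and only mirror the two subtraction lines above; they are not part of the original transforms.py.

```python
def shift_box(box, x, y):
    """Translate an [x1, y1, x2, y2] box into the frame of a crop
    whose top-left corner sits at (x, y) in the original image."""
    x1, y1, x2, y2 = box
    return [x1 - x, y1 - y, x2 - x, y2 - y]


# Hypothetical crop: top-left corner at (100, 50), side length 200
crop_x, crop_y, crop_size = 100, 50, 200

inside = shift_box([120, 80, 180, 160], crop_x, crop_y)     # fully inside the crop
partial = shift_box([60, 80, 180, 160], crop_x, crop_y)     # sticks out on the left
outside = shift_box([400, 300, 450, 360], crop_x, crop_y)   # entirely outside

print(inside)   # [20, 30, 80, 110]
print(partial)  # [-40, 30, 80, 110]
print(outside)  # [300, 250, 350, 310]
```

Only the first box ends up within the 200 x 200 crop. The second has a negative x1, and the third lies completely beyond the crop's right and bottom edges.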

Changes to train.py

Add the RandomCrop transform

    data_transform = {
        "train": transforms.Compose([transforms.ToTensor(),
                                     transforms.RandomHorizontalFlip(0.5),
                                     transforms.RandomCrop(1024)]),
        "val": transforms.Compose([transforms.ToTensor()])
    }
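
Because RandomCrop takes and returns an (image, target) pair, the `transforms.Compose` used here must be the project's own version rather than torchvision's, whose Compose passes a single input through each transform. The source does not show that class, so the sketch below is an assumption about what it roughly looks like. `Tag` is a hypothetical stand-in transform used only to show the call order.

```python
class Compose(object):
    """Chain transforms that each take and return an (image, target) pair."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        # Apply the transforms in list order, threading both values through
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class Tag(object):
    """Hypothetical stand-in transform: records its name in target["applied"]."""
    def __init__(self, name):
        self.name = name

    def __call__(self, image, target):
        target["applied"].append(self.name)
        return image, target


pipeline = Compose([Tag("ToTensor"), Tag("RandomHorizontalFlip"), Tag("RandomCrop")])
image, target = pipeline("image", {"applied": []})
print(target["applied"])  # ['ToTensor', 'RandomHorizontalFlip', 'RandomCrop']
```

The order matters: RandomCrop reads `image.shape[-2:]` and slices the image as (C, H, W), so ToTensor has to run before it.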

Errors seen during training:

The loss climbed above 50.0.

Partway through training, once the loss exceeds 100, a "loss is nan" error is raised.
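
One plausible explanation, offered as an assumption rather than a confirmed diagnosis: RandomCrop shifts the boxes but never clips them, so a box that lies partly or wholly outside the crop window keeps negative or out-of-range coordinates, and such targets can destabilize the box-regression loss. A minimal sketch of a remedy is to clip the shifted boxes to the crop window and drop those left with no area. The helper `clip_and_filter_boxes` is hypothetical and uses plain Python lists for clarity.

```python
def clip_and_filter_boxes(boxes, crop_w, crop_h):
    """Clip already-shifted [x1, y1, x2, y2] boxes to the crop window
    and drop boxes that end up with zero width or height.

    Returns the surviving boxes and their original indices, so that
    per-box labels and masks can be filtered with the same indices.
    """
    kept_boxes, kept_indices = [], []
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        # Clamp every coordinate into [0, crop_w] or [0, crop_h]
        x1, x2 = min(max(x1, 0), crop_w), min(max(x2, 0), crop_w)
        y1, y2 = min(max(y1, 0), crop_h), min(max(y2, 0), crop_h)
        # Keep the box only if some area remains inside the crop
        if x2 > x1 and y2 > y1:
            kept_boxes.append([x1, y1, x2, y2])
            kept_indices.append(i)
    return kept_boxes, kept_indices


# Hypothetical boxes, already expressed in crop coordinates (200 x 200 crop)
shifted = [[20, 30, 80, 110], [-40, 30, 80, 110], [300, 250, 350, 310]]
boxes, keep = clip_and_filter_boxes(shifted, crop_w=200, crop_h=200)
print(boxes)  # [[20, 30, 80, 110], [0, 30, 80, 110]]
print(keep)   # [0, 1]
```

In the tensor-based transform, the same idea would mean clamping target["boxes"] and indexing every per-box entry of target (boxes, masks and any per-box labels) with the surviving indices, so the entries stay aligned.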

Source: https://blog.csdn.net/llf000000/article/details/134946940