【论文复现】zoedepth踩坑

2023-12-14 00:13:06

注意模型IO:
保证输入、输出精度、类型与复现目标一致。

模型推理的代码

from torchvision import transforms
def image_to_tensor(img_path, unsqueeze=True):
    """
    Load an image file and convert it with ``transforms.ToTensor``.

    Parameters:
    img_path: Path of the image file, passed to ``Image.open``.
    unsqueeze (bool): When true, a dimension of size 1 is inserted at
        position 0 of the converted tensor.

    Returns:
    The converted tensor, with the extra leading dimension if requested.
    """
    # Image.open is lazy and keeps the file handle open; converting inside
    # the context manager releases the handle as soon as the pixels are read.
    with Image.open(img_path) as img:
        rgb = transforms.ToTensor()(img)
    return rgb.unsqueeze(0) if unsqueeze else rgb


def disparity_to_tensor(disp_path, unsqueeze=True, bit_depth=16):
    """
    Load a disparity image and scale it by the full-scale value of its bit depth.

    Parameters:
    disp_path: Path of the disparity image. It is read with the ``-1``
        (unchanged) flag so the stored sample type is kept.
    unsqueeze (bool): When true, a second leading dimension of size 1 is
        added on top of the one that is always added.
    bit_depth (int): Bit depth the file was saved with. Raw values are
        divided by ``2 ** bit_depth - 1``. Defaults to 16; pass 8 for
        8-bit files.

    Returns:
    torch.Tensor: float32 tensor of the scaled values with one leading
    dimension added, or two when ``unsqueeze`` is true.

    Raises:
    FileNotFoundError: If ``cv2.imread`` could not read ``disp_path``.
    """
    raw = cv2.imread(disp_path, -1)
    if raw is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # rather than raising, which would otherwise surface as an opaque
        # TypeError on the division below.
        raise FileNotFoundError(
            "cv2.imread could not read disparity map: {}".format(disp_path)
        )

    full_scale = 2 ** bit_depth - 1
    disp = torch.from_numpy(raw / full_scale)[None, ...]
    if unsqueeze:
        disp = disp.unsqueeze(0)
    return disp.float()

输入图像:uint8 960*1280

# Load the input image; if the given path cannot be opened, retry with the
# last three characters of the path replaced by 'jpg'.
try:
    image = image_to_tensor(img_path)
except OSError:
    # Only file-level failures (missing or unreadable image) trigger the
    # fallback. A bare ``except`` would also retry on CUDA errors or
    # KeyboardInterrupt, hiding the real problem.
    image = image_to_tensor(img_path[:-3] + 'jpg')
image = image.cuda()  # author's note: expected layout is [1,3,h,w]

# Single-channel input: replicate the channel three times along dim 1.
if image.shape[1] == 1:
    image = torch.tile(image, dims=(1, 3, 1, 1))

# NOTE: extra rescaling steps were tried here and left disabled, e.g.
#   image.type(torch.float32) / 255 / 255
#   image.float() / (2 ** 16 - 1)

# Keep at most the first three channels (drops e.g. an alpha channel).
image = image[:, 0:3, ...]

使用 NumPy(cv2.imread)加载图像进行测试时,必须先把像素值归一化到 [0, 1]。

# Read the file as a single-channel (grayscale) array.
image_np = cv2.imread(input_pic, cv2.IMREAD_GRAYSCALE)

# (A NumPy-side channel expansion via np.repeat was tried and left disabled;
# the expansion is done on the tensor below instead.)

# Move to the GPU, prepend two singleton axes, and cast to float32.
# Values are NOT rescaled in this snippet.
pic_int = torch.from_numpy(image_np).cuda()[None, None].float()

# Replicate a lone channel three times along dim 1.
if pic_int.size(1) == 1:
    pic_int = pic_int.repeat(1, 3, 1, 1)

self.zoe(pic_int)

归一化对比

失败

# Divide by 255 so the values land in [0, 1].
# Assumes pic_int holds 8-bit intensities (0-255) - confirm for your data.
pic_int = pic_int / 255.0

成功

有的数据集中,图像本身的数值范围就很大。比如以 16 位(uint16)格式保存的图像,读入后的像素值如下:

tensor([[[[13629, 13629, 14012,  ..., 21654, 21017, 20635],
          [13629, 12993, 13247,  ..., 21654, 21781, 21399],
          [12993, 12865, 12738,  ..., 21145, 21017, 21272],
          ...,
          [17196, 17069, 16941,  ..., 21399, 20890, 22291],
          [17069, 17196, 16814,  ..., 21399, 21909, 22291],
          [17196, 17705, 16686,  ..., 21527, 22036, 65535]]]], device='cuda:0',
       dtype=torch.int32)

此时就需要二次归一化:例如再按 16 位的满量程(2^16 − 1 = 65535)缩放一次,使数值落在 [0, 1] 范围内。

归一化测试脚本

import torch
import warnings

def check_tensor_values(tensor):
    """
    Check the maximum and minimum values of a tensor.
    Issues a warning if the maximum value is greater than 1 or the minimum value is less than 0.001.
    Also warns, instead of crashing or passing silently, when the tensor is
    empty or contains NaN values.

    Parameters:
    tensor (torch.Tensor): The tensor to check.

    Returns:
    None. Problems are reported only through ``UserWarning``.
    """
    # torch.max / torch.min raise RuntimeError on an empty tensor, so report
    # the empty case explicitly and stop.
    if tensor.numel() == 0:
        warnings.warn("The tensor is empty; there are no values to check.",
                      UserWarning, stacklevel=2)
        return

    # NaN propagates through max/min and every comparison with NaN is False,
    # so a NaN-contaminated tensor would otherwise pass without any warning.
    if torch.isnan(tensor).any():
        warnings.warn("The tensor contains NaN values!", UserWarning,
                      stacklevel=2)

    max_value = torch.max(tensor)
    min_value = torch.min(tensor)

    # stacklevel=2 makes each warning point at the caller's line.

    # Check for the maximum value condition
    if max_value > 1:
        warnings.warn("The maximum value is greater than 1!", UserWarning,
                      stacklevel=2)

    # Check for the minimum value condition
    if min_value < 0.001:
        warnings.warn("The minimum value is less than 0.001!", UserWarning,
                      stacklevel=2)

图像转换16位

from PIL import Image
import numpy as np

# Open the source image. The context manager closes the file handle once the
# grayscale copy has been produced.
with Image.open('lutao_exp/thermal/left_thermal_darkpre_ft_0_epoch_out_depth_colored.png') as image:
    # Collapse to 8-bit grayscale (mode 'L') before widening.
    image_gray = image.convert('L')

# Widen to uint16 and stretch to the full 16-bit range: 255 * 257 == 65535.
# (Multiplying by 256 would top out at 65280, so white would not reach 1.0
# after dividing by 2 ** 16 - 1.)
image_gray_16bit = np.array(image_gray, dtype=np.uint16) * 257

# Wrap the array in a new 16-bit Pillow image.
image_16bit_pillow = Image.fromarray(image_gray_16bit, mode='I;16')

# Save the 16-bit image.
image_16bit_pillow.save('lutao_exp/thermal/left_thermal_darkpre_ft_0_epoch_out_depth_colored16.png')

文章来源:https://blog.csdn.net/prinTao/article/details/134904477
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。