【论文复现】ZoeDepth 踩坑记录
2023-12-14 00:13:06
注意模型的输入/输出(IO):
确保输入、输出的数值精度、数据类型以及取值范围都与复现目标保持一致。
模型推理相关代码如下:
from PIL import Image
from torchvision import transforms


def image_to_tensor(img_path, unsqueeze=True):
    """Open an image file and convert it to a tensor with ``transforms.ToTensor``.

    Args:
        img_path: Path of the image file to open with PIL.
        unsqueeze: If True, prepend a batch dimension
            (``[C, H, W]`` -> ``[1, C, H, W]``).

    Returns:
        torch.Tensor: the converted image, with a leading batch dim when
        ``unsqueeze`` is True.

    NOTE(review): torchvision's ToTensor divides by 255 only for 8-bit data.
    16-bit images (PIL modes such as 'I' / 'I;16') come back as integer
    tensors holding the raw values (e.g. up to 65535), so callers must
    normalize them before inference. Confirm this for your torchvision version.
    """
    # Use a context manager so the file handle is released immediately.
    # ToTensor copies the pixel data, so the tensor stays valid after closing.
    with Image.open(img_path) as img:
        rgb = transforms.ToTensor()(img)
    if unsqueeze:
        rgb = rgb.unsqueeze(0)
    return rgb
def disparity_to_tensor(disp_path, unsqueeze=True, scale=2 ** 16 - 1):
    """Read a disparity image and return it as a float32 tensor scaled by ``scale``.

    Args:
        disp_path: Path of the disparity image. It is read with
            ``cv2.IMREAD_UNCHANGED`` (flag -1), so 16-bit files keep their
            full bit depth.
        unsqueeze: If True, add one more leading dimension
            (``[1, H, W]`` -> ``[1, 1, H, W]`` for a single-channel image).
        scale: Divisor applied to the raw pixel values. The default
            ``2**16 - 1`` maps a 16-bit file to [0, 1]. Pass ``2**8 - 1``
            for 8-bit files.

    Returns:
        torch.Tensor: float32 tensor with a leading channel dim added
        (and a batch dim when ``unsqueeze`` is True).

    Raises:
        OSError: If OpenCV cannot read ``disp_path`` (missing or unreadable).
    """
    raw = cv2.imread(disp_path, cv2.IMREAD_UNCHANGED)  # same as flag -1
    if raw is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read disparity map: {disp_path}")
    disp = torch.from_numpy(raw / scale)[None, ...]
    if unsqueeze:
        disp = disp.unsqueeze(0)
    return disp.float()
输入图像:uint8 类型,尺寸 960×1280。
import os

# Load the input image as a [1, C, h, w] tensor. If the given path cannot be
# opened, fall back to a same-named .jpg file (presumably some frames of the
# dataset are stored as JPEG -- confirm).
try:
    image = image_to_tensor(img_path)
except OSError:
    # PIL raises FileNotFoundError / UnidentifiedImageError, both OSError
    # subclasses. splitext also handles extensions that are not 3 chars long.
    image = image_to_tensor(os.path.splitext(img_path)[0] + '.jpg')
# Move to GPU outside the try, so CUDA errors are not mistaken for a bad path.
image = image.cuda()  # [1, C, h, w]

# ToTensor only rescales 8-bit images to [0, 1]. 16-bit images arrive as raw
# integer values (up to 65535) and need a second normalization step.
if not image.is_floating_point():
    image = image.float() / (2 ** 16 - 1)

# Grayscale input: replicate the single channel to 3 channels.
if image.shape[1] == 1:
    image = torch.tile(image, dims=(1, 3, 1, 1))
# Keep only the first 3 channels (drops e.g. an alpha channel).
image = image[:, 0:3, ...]
用 OpenCV 读入 NumPy 数组进行测试时,必须先把像素值归一化到 [0, 1]:
# Read the test image as single-channel 8-bit grayscale (values 0..255).
image_np = cv2.imread(input_pic, cv2.IMREAD_GRAYSCALE)
# Disabled alternative: replicate the gray channel in NumPy instead of torch.
# if len(image_np.shape) == 2:
# image_np= np.repeat(image_np[:, :, np.newaxis], 3, axis=2)
# [H, W] -> [1, 1, H, W] on the GPU. NOTE: .float() only changes the dtype;
# the values are still in [0, 255], i.e. NOT normalized to [0, 1].
pic_int = torch.from_numpy(image_np).cuda().unsqueeze(0).unsqueeze(0).float()
# Replicate the single gray channel to 3 channels.
if pic_int.shape[1] == 1:
    pic_int = torch.tile(pic_int , dims=(1, 3, 1, 1))
# Run the model (presumably the ZoeDepth instance held by the enclosing object).
# Per the article, this call fails because the input was never divided by 255.
self.zoe(pic_int )
结果:失败(输入值仍在 [0, 255] 范围内,没有归一化)。
pic_int = pic_int.div(255)  # rescale 8-bit pixel values [0, 255] -> [0, 1]
加上这行归一化(除以 255)后:成功。
有些数据集中的图像本身取值范围就很大,例如以 16 位(16-bit)格式保存的图像,读入后得到的张量如下:
tensor([[[[13629, 13629, 14012, ..., 21654, 21017, 20635],
[13629, 12993, 13247, ..., 21654, 21781, 21399],
[12993, 12865, 12738, ..., 21145, 21017, 21272],
...,
[17196, 17069, 16941, ..., 21399, 20890, 22291],
[17069, 17196, 16814, ..., 21399, 21909, 22291],
[17196, 17705, 16686, ..., 21527, 22036, 65535]]]], device='cuda:0',
dtype=torch.int32)
此时就需要再做一次归一化,例如除以 2**16 - 1(即 65535),把数值缩放到 [0, 1]。
归一化检查脚本(检查张量取值是否在预期范围内):
import torch
import warnings
def check_tensor_values(tensor, *, upper=1, lower=0.001):
    """
    Check the maximum and minimum values of a tensor.

    Issues a ``UserWarning`` if the maximum value is greater than ``upper``
    (default 1) or the minimum value is less than ``lower`` (default 0.001).
    A maximum above 1 typically means the data is still in raw pixel units
    (e.g. 0..255 or 0..65535) and has not been normalized. Tensors that
    contain NaN also trigger a warning, and empty tensors are skipped.

    Parameters:
    tensor (torch.Tensor): The tensor to check.
    upper: Threshold for the maximum-value warning (keyword-only).
    lower: Threshold for the minimum-value warning (keyword-only).
    """
    if tensor.numel() == 0:
        # torch.max / torch.min raise on empty tensors; nothing to check.
        return
    if tensor.is_floating_point() and torch.isnan(tensor).any():
        # NaN propagates through max/min and compares False with everything,
        # so it would otherwise slip silently past both range checks below.
        warnings.warn("The tensor contains NaN values!", UserWarning, stacklevel=2)
    max_value = torch.max(tensor).item()
    min_value = torch.min(tensor).item()
    # Check for the maximum value condition
    if max_value > upper:
        warnings.warn(f"The maximum value is greater than {upper}!", UserWarning, stacklevel=2)
    # Check for the minimum value condition
    if min_value < lower:
        warnings.warn(f"The minimum value is less than {lower}!", UserWarning, stacklevel=2)
将图像转换为 16 位灰度图:
from PIL import Image
import numpy as np
# Convert a saved 8-bit image into a 16-bit grayscale PNG.
in_path = 'lutao_exp/thermal/left_thermal_darkpre_ft_0_epoch_out_depth_colored.png'
out_path = 'lutao_exp/thermal/left_thermal_darkpre_ft_0_epoch_out_depth_colored16.png'

# Open the image and convert it to 8-bit grayscale ('L' mode, values 0..255).
# The context manager closes the file; convert() returns an in-memory copy.
# NOTE(review): the file name suggests a *colored* depth visualization;
# converting a colormap to luminance does not recover the underlying depth
# values -- confirm this is intended.
with Image.open(in_path) as image:
    image_gray = image.convert('L')

# Stretch 0..255 onto the full 16-bit range 0..65535. The factor 257
# (= 65535 / 255) maps 255 exactly to 65535, so a later division by
# 2**16 - 1 yields exactly 1.0. The previous factor 256 topped out at 65280.
image_gray_16bit = np.array(image_gray, dtype=np.uint16) * 257

# Pillow infers a 16-bit grayscale mode from a 2-D uint16 array. The explicit
# mode= argument of fromarray is deprecated in recent Pillow releases.
image_16bit_pillow = Image.fromarray(image_gray_16bit)
# Save the 16-bit image.
image_16bit_pillow.save(out_path)
文章来源:https://blog.csdn.net/prinTao/article/details/134904477
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!