pytorch debug 常用工具
2023-12-13 15:36:31
自动辨识图像格式可视化
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def convert_to_numpy(image_input):
"""
自动检测输入图像类型,并将其转换为NumPy数组。
"""
if isinstance(image_input, np.ndarray):
# 输入已经是NumPy数组,直接返回
return image_input
elif 'Tensor' in str(type(image_input)):
# 输入是Tensor类型
# 检查是否需要转换(依赖于Tensor所属的库,如PyTorch, TensorFlow等)
if hasattr(image_input, 'detach'):
# 假设是PyTorch Tensor
image_input = image_input.detach().cpu().numpy()
else:
# 假设是TensorFlow Tensor或其他框架的Tensor
image_input = image_input.numpy()
# 如果Tensor有通道维度在最前面(如CHW),则需要转换为HWC
if image_input.ndim == 3 and image_input.shape[0] in (1, 3):
image_input = image_input.transpose(1, 2, 0)
elif isinstance(image_input, Image.Image):
# 输入是Pillow图像,转换为NumPy数组
image_input = np.array(image_input)
else:
raise TypeError("Unsupported image type")
# 如果图像是单通道的,且在最后一个维度(例如HxWx1),去掉该维度
if image_input.ndim == 3 and image_input.shape[-1] == 1:
image_input = image_input.squeeze(-1)
image_np = image_input
if image_np.ndim == 3 and image_np.shape[-1] == 3:
plt.imshow(image_np)
else:
plt.imshow(image_np, cmap='viridis')
plt.title(title)
plt.axis('off')
plt.show()
def visualize_image(image_np, title="Image"):
"""
可视化NumPy格式的图像
"""
if image_np.ndim == 3 and image_np.shape[-1] == 3:
plt.imshow(image_np)
else:
plt.imshow(image_np, cmap='gray')
plt.title(title)
plt.axis('off')
plt.show()
# 示例使用
# image_tensor, image_np, image_pil 分别代表Tensor, NumPy数组, Pillow图像的输入
# 将它们转换为NumPy数组
# image_np = convert_to_numpy(image_tensor)
# image_np = convert_to_numpy(image_np)
# image_np = convert_to_numpy(image_pil)
# # 可视化图像
# visualize_image(image_np)
可视化
张量可视化
import torch
from torchvision.transforms.functional import to_pil_image
from PIL import Image
def tensor_to_pil(tensor):
# 确保tensor是在CPU上
tensor = tensor.cpu()
# 如果tensor有一个批次维度,去除它
if tensor.dim() == 4 and tensor.shape[0] == 1:
tensor = tensor.squeeze(0)
# 转换为PIL图像
pil_image = to_pil_image(tensor)
# 返回PIL图像
return pil_image
tensor_to_pil( ).show()
可视化已经图像信息
def draw_np(pic_np):
pic_np = np.squeeze(pic_np)
plt.imshow(pic_np)
# 隐藏坐标轴
plt.axis('on')
# 显示数据标尺
plt.colorbar()
# 显示图像
plt.show()
def get_image_info(image):
# 获取图像的模式、格式和尺寸
mode = image.mode
format_ = image.format
size = image.size
# 根据图像模式推断每个通道的位数
if mode in ("1", "L", "P"):
bits_per_channel = 8 # 通常是8位
elif mode == "RGB":
bits_per_channel = 8 # 通常是8位,3通道
elif mode == "RGBA":
bits_per_channel = 8 # 通常是8位,4通道
elif mode == "I":
bits_per_channel = 32 # 整数像素模式
elif mode == "F":
bits_per_channel = 32 # 浮点像素模式
else:
bits_per_channel = 'unknown' # 未知或不常见的模式
# 计算总位数
total_bits = image.getbands().__len__() * bits_per_channel
# 打印图像信息
print(f"Image mode: {mode}")
print(f"Image format: {format_}")
print(f"Image size: {size}")
print(f"Bits per channel: {bits_per_channel}")
print(f"Total bits per pixel: {total_bits}")
#%%
import numpy as np
def get_array_info(np_array):
"""
获取并打印NumPy数组的详细信息。
参数:
np_array: NumPy数组。
"""
# 获取数组的形状
shape = np_array.shape
# 获取数组的总元素数量
size = np_array.size
# 获取数组的数据类型
dtype = np_array.dtype
# 获取数组单个元素的大小(以字节为单位)
itemsize = np_array.itemsize
# 获取数组的维度数量
ndim = np_array.ndim
# 获取数组的总字节数
nbytes = np_array.nbytes
# 打印数组信息
print(f"Array Shape: {shape}")
print(f"Array Size: {size}")
print(f"Array Data Type: {dtype}")
print(f"Item Size: {itemsize} bytes")
print(f"Array Dimensions: {ndim}")
print(f"Total Bytes: {nbytes} bytes")
def read_pic(path_pic):
# 加载图像
image = Image.open(path_pic)
print(image.size)
print(image.format)
return image
def pic_to_np(pic):
np_depth = np.array(pic)
return np_depth
def draw_np(pic_np):
pic_np = np.squeeze(pic_np)
plt.imshow(pic_np)
# 隐藏坐标轴
plt.axis('on')
# 显示数据标尺
plt.colorbar()
# 显示图像
plt.show()
def pic_info(path):
raw_image = read_pic(path)
raw_np = pic_to_np(raw_image)
get_image_info(raw_image)
get_array_info(raw_np)
raw_image.show()
draw_np(raw_np)
文章来源:https://blog.csdn.net/prinTao/article/details/134900653
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!