【debug】报错Assertion input_val >= zero && input_val <= one解决
2023-12-14 12:43:53
1.报错信息
????????在用服务器跑模型计算loss时,训练过程中报错,详细报错如下:Assertion?input_val >= zero && input_val <= one
?failed.和 RuntimeError: CUDA error: device-side assert triggered。
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [164,0,0], thread: [31,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [10,0,0], thread: [33,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [14,0,0], thread: [31,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [88,0,0], thread: [46,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [86,0,0], thread: [21,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [64,0,0], thread: [71,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [43,0,0], thread: [84,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [166,0,0], thread: [21,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [198,0,0], thread: [24,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [144,0,0], thread: [57,0,0]
Assertion?input_val >= zero && input_val <= one
...
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [34,0,0], thread: [91,0,0] Assertion?input_val >= zero && input_val <= one
2.原因分析
????????因为是在训练过程中,因此肯定不是所有的图像都有错,跑了几次之后,发现是在最后一个batch的训练过程中报错,且报错位置定位到loss计算,那就肯定是最后一个epoch算loss报错了。
????????打印出最后一个batch的输出尺寸,果然,输出维度为(1,2,256,256),按照(b,c,w,h)推理,b=1 那就意味着最后一个epoch的样本数为1,没有被整除导致,比如,3201个样本,batch_size=32,前面每个epoch训练都是32张训练样本,到最后一个epoch,只剩下一个样本了,就会报这个错。
3.解决方案
????????找到问题的源头在哪就好解决了,只需要在dataloader定义的位置设置一个drop_last=True参数,忽略不能整除的最后一个epoch即可。具体代码:
dataloader = DataLoader(dataset=source_dataset,
batch_size=config.batch_size,
shuffle=True,
pin_memory=True,
collate_fn=collate_fn_w_transform,
num_workers=config.num_workers,
drop_last=True)
????????当然,如果你不想浪费那个样本,也可以通过调整batch_size来解决该问题,修改原则就是num_data = epoch*batch_size,保证三者均是整数,即保证epoch能被整除。
整理不易,欢迎一键三连!!!
送你们一条美丽的--分割线--
🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸????🍎🍎👍👍🌷🌷
文章来源:https://blog.csdn.net/qq_38308388/article/details/134986983
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!