【debug】报错Assertion input_val >= zero && input_val <= one解决

2023-12-14 12:43:53

1.报错信息

????????在用服务器跑模型计算loss时,训练过程中报错,详细报错如下:Assertion?input_val >= zero && input_val <= one?failed.RuntimeError: CUDA error: device-side assert triggered。

../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [164,0,0], thread: [31,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [10,0,0], thread: [33,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [14,0,0], thread: [31,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [88,0,0], thread: [46,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [86,0,0], thread: [21,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [64,0,0], thread: [71,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [43,0,0], thread: [84,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [166,0,0], thread: [21,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [198,0,0], thread: [24,0,0] Assertion?input_val >= zero && input_val <= one
../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [144,0,0], thread: [57,0,0] 
Assertion?input_val >= zero && input_val <= one

...

../aten/src/ATen/native/cuda/Loss.cu:118: operator(): block: [34,0,0], thread: [91,0,0] Assertion?input_val >= zero && input_val <= one

2.原因分析

????????因为是在训练过程中,因此肯定不是所有的图像都有错,跑了几次之后,发现是在最后一个batch的训练过程中报错,且报错位置定位到loss计算,那就肯定是最后一个epoch算loss报错了。

????????打印出最后一个batch的输出尺寸,果然,输出维度为(1,2,256,256),按照(b,c,w,h)推理,b=1 那就意味着最后一个epoch的样本数为1,没有被整除导致,比如,3201个样本,batch_size=32,前面每个epoch训练都是32张训练样本,到最后一个epoch,只剩下一个样本了,就会报这个错。

3.解决方案

????????找到问题的源头在哪就好解决了,只需要在dataloader定义的位置设置一个drop_last=True参数,忽略不能整除的最后一个epoch即可。具体代码:

dataloader = DataLoader(dataset=source_dataset,
                                       batch_size=config.batch_size,
                                       shuffle=True,
                                       pin_memory=True,
                                       collate_fn=collate_fn_w_transform,
                                       num_workers=config.num_workers,
                                       drop_last=True)

????????当然,如果你不想浪费那个样本,也可以通过调整batch_size来解决该问题,修改原则就是num_data = epoch*batch_size,保证三者均是整数,即保证epoch能被整除。

整理不易,欢迎一键三连!!!

送你们一条美丽的--分割线--


🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸????🍎🍎👍👍🌷🌷

文章来源:https://blog.csdn.net/qq_38308388/article/details/134986983
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。