I3D代码讲解
2023-12-20 18:15:03
I3D网络是视频处理的时候经常使用的网络,最后输出分类的分数和概率。
网络结构
1.网络整体框架
2.网络整体框架中的 Inc. 模块具体结构
代码实现
1.关于Inception Module的实现,就是 inc. 模块
具体实现如下:b0、b1、b2、b3对应上图中的四个分支,最后在前向传播中用concat将他们拼接在一起。
class InceptionModule(nn.Module):
def __init__(self, in_channels, out_channels, name):
super(InceptionModule, self).__init__()
self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_0/Conv3d_0a_1x1')
self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_1/Conv3d_0a_1x1')
self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
name=name+'/Branch_1/Conv3d_0b_3x3')
self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_2/Conv3d_0a_1x1')
self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
name=name+'/Branch_2/Conv3d_0b_3x3')
self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
stride=(1, 1, 1), padding=0)
self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
name=name+'/Branch_3/Conv3d_0b_1x1')
self.name = name
def forward(self, x):
print(x.shape)
b0 = self.b0(x)
b1 = self.b1b(self.b1a(x))
b2 = self.b2b(self.b2a(x))
b3 = self.b3b(self.b3a(x))
return torch.cat([b0,b1,b2,b3], dim=1)
2.整体网络的实现
class InceptionI3d(nn.Module):
VALID_ENDPOINTS = (
'Conv3d_1a_7x7',
'MaxPool3d_2a_3x3',
'Conv3d_2b_1x1',
'Conv3d_2c_3x3',
'MaxPool3d_3a_3x3',
'Mixed_3b',
'Mixed_3c',
'MaxPool3d_4a_3x3',
'Mixed_4b',
'Mixed_4c',
'Mixed_4d',
'Mixed_4e',
'Mixed_4f',
'MaxPool3d_5a_2x2',
'Mixed_5b',
'Mixed_5c',
'Logits',
'Predictions',
)#网络各层的名字
def __init__(self, num_classes=400, spatial_squeeze=True,
final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
#final_endpoint:选择要在哪层结束然后输出结构
#如果final_endpoint不是网络中出现的层,那么说明这个网络层的名字是未知的
if final_endpoint not in self.VALID_ENDPOINTS:
raise ValueError('Unknown final endpoint %s' % final_endpoint)
super(InceptionI3d, self).__init__()
self._num_classes = num_classes
self._spatial_squeeze = spatial_squeeze
self._final_endpoint = final_endpoint
self.logits = None
if self._final_endpoint not in self.VALID_ENDPOINTS:
raise ValueError('Unknown final endpoint %s' % self._final_endpoint)
self.end_points = {}#存放网络名字和相应的层
end_point = 'Conv3d_1a_7x7'#网络层的名字,对应网络整体框架中的第一个模块,以下依次类推
self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
stride=(2, 2, 2), name=name+end_point)
if self._final_endpoint == end_point: return#如果该层是指定的最后一层,那么直接返回。
end_point = 'MaxPool3d_2a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Conv3d_2b_1x1'
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
name=name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Conv3d_2c_3x3'
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
name=name+end_point)
if self._final_endpoint == end_point: return
end_point = 'MaxPool3d_3a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Mixed_3b'
self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)
#这里InceptionModule中输出直接指定每个分支要输出的通道数,最后输出通道数为64+128+32+32,
#因为第一个分支直接指定输出通道,第二三个分支都有两个模块,所以要分别指出这两个模块的输出通道,
#但是只有最后一个模块的输出通道才是分支最终的输出通道,第四个模块经过池化通道数不变,所以只需要
#指出一个输出通道即可。以下依次类推。
if self._final_endpoint == end_point: return
end_point = 'Mixed_3c'
self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'MaxPool3d_4a_3x3'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4b'
self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4c'
self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4d'
self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4e'
self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_4f'
self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'MaxPool3d_5a_2x2'
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
padding=0)
if self._final_endpoint == end_point: return
end_point = 'Mixed_5b'
self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Mixed_5c'
self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)
if self._final_endpoint == end_point: return
end_point = 'Logits'
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
stride=(1, 1, 1))
self.dropout = nn.Dropout(dropout_keep_prob)
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
kernel_shape=[1, 1, 1],
padding=0,
activation_fn=None,
use_batch_norm=False,
use_bias=True,
name='Logits')
self.build()
#---------------------------------------#
# 当训练我们自己的数据集的时候,可以使用
# replace_logits来定义自己数据集中的类别数
#---------------------------------------#
def replace_logits(self, num_classes):
self._num_classes = num_classes
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
kernel_shape=[1, 1, 1],
padding=0,
activation_fn=None,
use_batch_norm=False,
use_bias=True,
name='logits')
#------------------------------------#
# 在上面的初始化函数中,我们已经将需要使用的
# 模块全部定义到end_points这个字典里面
# 这个函数的目的在于将这些模块全部加入至
# module里面进而方便后续的调用
#------------------------------------#
def build(self):
for k in self.end_points.keys():
self.add_module(k, self.end_points[k])
def forward(self, x):
for end_point in self.VALID_ENDPOINTS:
if end_point in self.end_points:
x = self._modules[end_point](x)
x = self.logits(self.dropout(self.avg_pool(x)))
if self._spatial_squeeze:
logits = x.squeeze(3).squeeze(3).squeeze(0)#这里是对尺度为1的通道进行挤压
# logits is batch X time X classes, which is what we want to work with
return logits
#------------------------------------#
# 只提取图像特征,并再最后过一层3d平均池化
# 得到最终的特征图像
#------------------------------------#
def extract_features(self, x):
for end_point in self.VALID_ENDPOINTS:
if end_point in self.end_points:
x = self._modules[end_point](x)
return self.avg_pool(x)
引用:代码参考处
输入:
通常是.npy文件形式的帧序列或是光流序列
文章来源:https://blog.csdn.net/weixin_45486992/article/details/135109211
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!