【代码分析】MPI

2023-12-14 20:42:24

代码解读

MPI

class MPIPredictor(nn.Module):
    def __init__(
        self,
        width=384,
        height=256,
        num_planes=64,
    ):
        super(MPIPredictor, self).__init__()
        self.num_planes = num_planes
        disp_range = [0.001, 1]
        self.far, self.near = disp_range

self.far, self.near = disp_range 规定深度,必须归一化。
特征掩码: 训练过程中动态地隐藏或者“掩盖”数据的一部分,从而使模型能够更加健壮,能够处理不完整的输入数据,或者是为了增强模型对数据的理解。
在使用注意力机制时,特征掩码可以用来防止模型在计算注意力分布时考虑到某些不应该被看见的信息.
在解码器中防止模型“看到”未来的信息。

    def forward(
        self, 
        src_imgs,   # 源图像和相应的深度图
        src_depths, 
    ):  
        # 源图像和深度图下采样到 self.low_res_size 指定的较低分辨率
        # align_corners=True 参数确保下采样后角点对齐
        rgb_low_res = F.interpolate(src_imgs, size=self.low_res_size, mode='bilinear', align_corners=True)
        disp_low_res = F.interpolate(src_depths, size=self.low_res_size, mode='bilinear', align_corners=True)
        
        bs = src_imgs.shape[0]
        # 创建了一个在 self.near 和 self.far 之间的视差线性空间,点数为 self.num_planes + 2。
        # [1:-1] 切片去掉了第一个和最后一个值。这个张量随后被移到和 src_imgs 相同的设备上(为了支持 GPU),
        # 增加了一个批次维度,并重复以匹配批次中的每个元素。
        dpn_input_disparity = torch.linspace(
            self.near, 
            self.far, 
            self.num_planes + 2
        )[1:-1].to(src_imgs.device).unsqueeze(0).repeat(bs, 1)
        
        # 使用准备好的视差张量以及下采样后的图像和深度图来生成渲染的视差
        render_disp = self.dpn(dpn_input_disparity, rgb_low_res, disp_low_res)
        # 使用原始图像、深度和渲染的视差来生成特征掩码
        feature_mask = self.fmn(src_imgs, src_depths, render_disp)
        
        # Encoder forward
        # 处理输入的图像和深度图通过几层或几个块,输出结果被存储在编码器特征列表中
        conv1_out, block1_out, block2_out, block3_out, block4_out = self.encoder(src_imgs, src_depths)
        enc_features = [conv1_out, block1_out, block2_out, block3_out, block4_out]
        # Decoder forward
        # 接收编码器特征和特征掩码作为输入,产生最终的输出
        outputs = self.decoder(enc_features, feature_mask)
        
        return outputs[0], render_disp  # 最终预测的图像或深度图,以及渲染的视差

PAN

平面调整网络
没找到定义方式,PAN.py中没有

CPN

颜色预测网络

编码 、解码、U-Net

DPN

class DepthPredictionNetwork(nn.Module):
    """
    从初始视差(init_disp)、低分辨率RGB图像(rgb_low_res)和低分辨率视差信息(disp_low_res)中预测深度值。
    """
    def __init__(self, disp_range, **kwargs):
        super().__init__()
        # 降低输入数据的维度,同时增加特征维度
        self.context_encoder = DownsizeEncoder(num_blocks=5, dim_in=5, dim_out=128)
        # MultiheadSelfAttention自注意力模块,它专门用于处理序列化的向量
        self.self_attention = MultiheadSelfAttention(num_heads=4, dim_in=128, dim_qk=32, dim_v=128) 
        # 嵌入层embed,它将特征进一步处理为更小的维度(从128降到32),然后是一个ReLU激活函数。
        self.embed = nn.Sequential(
            nn.Linear(128, 32),
            nn.ReLU(),
        )
        # 将嵌入层的输出进一步处理并映射到视差值的范围内
        self.to_disp = LinearSigmoid(32, disp_range)

    def forward(self, init_disp, rgb_low_res, disp_low_res):
        B, S = init_disp.shape
        
        # context encoder
        # 将RGB图像、低分辨率的视差信息和初始视差拼接在一起,形成一个有5个通道的输入
        x = torch.cat([
                rgb_low_res[:, None, ...].repeat(1, S, 1, 1, 1), 
                disp_low_res[:, None, ...].repeat(1, S, 1, 1, 1), 
                init_disp[:, :, None, None, None].repeat(1, 1, 1, *rgb_low_res.shape[-2:])
            ], dim=-3)   # [b, s, 5, h/4, w/4]
        
        x = x.view(-1, *x.shape[-3:])  # [b*s, 5, h/4, w/4]
        # 降低输入数据的维度,同时增加特征维度
        context = self.context_encoder(x)   # [b*s, c, h/128, w/128]
        # 自适应平均池化(F.adaptive_avg_pool2d)将特征降维,并压缩到一个向量。
        context = F.adaptive_avg_pool2d(context, (1, 1)).squeeze(-1).squeeze(-1)  # [b*s, c]
        context = context.view(B, S, -1)  # [b, s, c]
        
        # self attention
        # 自注意力处理,以捕获不同初始视差之间的关系
        feat_atted = self.self_attention(context)   # [b, s, c ]
        # 更小的维度
        feat = self.embed(feat_atted)  # [b, s, c]
        # 特征映射为最终的视差预测值
        disp_bs = self.to_disp(feat, init_disp)  # [b, s]
        return disp_bs

FMN

UNET

class FeatMaskNetwork(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.conv1 = ConvBNReLU(5, 16, 3, 1, 1)
        self.conv2 = ConvBNReLU(16, 32, 3, 2, 1)
        self.conv3 = ConvBNReLU(32, 64, 3, 2, 1)
        self.conv4 = ConvBNReLU(64, 128, 3, 2, 1)
        self.conv5 = ConvBNReLU(128, 128, 3, 1, 1)
        self.conv6 = ConvBNReLU(192, 64, 3, 1, 1)
        self.conv7 = ConvBNReLU(96, 32, 3, 1, 1)
        self.conv8 = ConvBNReLU(48, 16, 3, 1, 1)
        self.conv9 = ConvBNReLU(16, 1, 3, 1, 1)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)    # 上采样层

    def forward(self, input_image, input_depth, input_mpi_disparity):
        '''
        input_image: [b,3,h,w]
        input_depth: [b,1,h,w]      深度数据
        input_mpi_disparity: [b,s]  多平面图像(MPI)视差
        '''

        """
        将 input_image、input_depth 和 input_mpi_disparity 扩展到相同的维度,使得它们能够在空间维度上与每个MPI平面对齐。
        """
        _, _, h, w = input_image.size()  # spatial dim
        b, s = input_mpi_disparity.size()  # number of mpi planes
        # repeat input rgb
        expanded_image = input_image.unsqueeze(1).repeat(1, s, 1, 1, 1)  # [b,s,3,h,w]
        # repeat input depth
        expanded_depth = input_depth.unsqueeze(1).repeat(1, s, 1, 1, 1)  # [b,s,1,h,w]
        # repeat and reshape input mpi disparity
        expanded_mpi_disp = input_mpi_disparity[:, :, None, None, None].repeat(1, 1, 1, h, w)  # [b,s,1,h,w]

        # concat together
        # 将这些扩展后的数据沿通道维度拼接,并重塑成 [bs,5,h,w] 的形状
        x = torch.cat([expanded_image, expanded_depth, expanded_mpi_disp], dim=2).reshape(b * s, 5, h, w)  # [bs,5,h,w]
        
        # forward
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        c5 = self.conv5(c4)
        u5 = self.upsample(c5)
        c6 = self.conv6(torch.cat([u5, c3], dim=1))
        u6 = self.upsample(c6)
        c7 = self.conv7(torch.cat([u6, c2], dim=1))
        u7 = self.upsample(c7)
        c8 = self.conv8(torch.cat([u7, c1], dim=1))
        c9 = self.conv9(c8)  # [bs,1,h,w]
        fm = c9.reshape(b, s, h, w)
        # 特征蒙版经过 softmax 函数处理,以获取每个MPI平面上每个像素位置的归一化权重。
        # 对每个MPI平面上的特征贡献的概率分布,这在合成新视角图像或进行深度估计时非常有用。
        # 通过 softmax 保证了每个像素位置上所有平面权重的和为1,使其可以被解释为概率
        fm = torch.softmax(fm ,dim=1)

        return fm
        

文章来源:https://blog.csdn.net/prinTao/article/details/135002303
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。