TrustGeo代码理解(三)model.py

2023-12-16 18:30:01

代码链接:https://github.com/ICDM-UESTC/TrustGeo

一、导入各种模块和神经网络类

from math import gamma
from re import L
from .layers import *
import torch
import torch.nn as nn
import torch.nn.functional as Func
import numpy as np

这段代码是一个 Python 模块,包含了一些导入语句和定义了一个神经网络模型的类。

1、from math import gamma:导入了 gamma 函数,这是 Python 标准库 math 模块中的一个函数,用于计算伽玛函数。需要注意的是,下文的模型代码并没有调用这个函数(gamma_1、gamma_2、gamma_3 是模型自己定义的可学习参数,与它无关),因此这是一条未被使用的导入。
2、from re import L:从正则表达式模块 re 中导入了常量 L。re.L 是 re.LOCALE 标志的简写,所以这条语句本身在语法上是合法的,并不会报错;但下文的模型代码完全没有用到它,很可能是 IDE 自动补全时误加的导入,可以安全删除。
3、from .layers import *:导入了当前模块所在目录中的 layers 模块中的所有内容。* 表示导入所有的内容。
4、import torch:导入了 PyTorch 库中的相关模块。torch 是主要的 PyTorch 模块。
5、import torch.nn as nn:导入了 PyTorch 的 torch.nn 子模块并取别名 nn。nn 包含了构建神经网络所需的基本组件(如 nn.Module、nn.Linear、nn.Parameter 等)。
6、import torch.nn.functional as Func:导入了 PyTorch 库中的相关模块。functional 模块包含了一些与神经网络相关的函数。
7、import numpy as np:导入了 NumPy 库,NumPy 是一个用于科学计算的 Python 库,提供了大量用于数组操作的函数。

二、TrustGeo类定义(NN模型)

class TrustGeo(nn.Module):
    """Two-view model that emits evidential outputs for each target.

    A "graph view" (two rounds of propagation over a landmark/target/router
    graph) and an "attribute view" (the raw target features) each feed a
    linear head that produces five values per target: ``gamma1``, ``gamma2``,
    ``v``, ``alpha`` and ``beta``.  ``gamma1``/``gamma2`` are returned
    unconstrained (presumably the predicted coordinates -- confirm against
    the loss function); the other three are constrained by :meth:`trans`.
    """

    def __init__(self, dim_in):
        """Build the layers.

        :param dim_in: width of the per-node feature vectors ``lm_X``/``tg_X``.
            Node features inside the graph are ``dim_in + 2`` wide because the
            two location columns are concatenated onto them.
        """
        super().__init__()
        self.dim_in = dim_in
        self.dim_z = dim_in + 2

        # TrustGeo
        # Attention between targets (queries) and landmarks (keys); only its
        # second return value (the score matrix) is used in forward().
        self.att_attribute = SimpleAttention(temperature=self.dim_z ** 0.5,
                                             d_q_in=self.dim_in,
                                             d_k_in=self.dim_in,
                                             d_v_in=self.dim_in + 2,
                                             d_q_out=self.dim_z,
                                             d_k_out=self.dim_z,
                                             d_v_out=self.dim_z)

        # calculate A
        # Learnable scalars turning a delay d into an edge weight
        # exp(-gamma_i * (alpha * d + beta)).
        self.gamma_1 = nn.Parameter(torch.ones(1, 1))
        self.gamma_2 = nn.Parameter(torch.ones(1, 1))
        self.gamma_3 = nn.Parameter(torch.ones(1, 1))
        self.alpha = nn.Parameter(torch.ones(1, 1))
        self.beta = nn.Parameter(torch.zeros(1, 1))

        # transform in Graph
        self.w_1 = nn.Linear(self.dim_in + 2, self.dim_in + 2)
        self.w_2 = nn.Linear(self.dim_in + 2, self.dim_in + 2)

        # higher-order evidence
        # graph view: input is the step-1 and step-2 target features
        # concatenated, hence dim_z * 2.
        self.out_layer_graph_view = nn.Linear(self.dim_z * 2, 5)
        # attribute view: input is the raw target feature vector.
        self.out_layer_attri_view = nn.Linear(self.dim_in, 5)

    # for output mu, v, alpha, beta
    def evidence(self, x):
        """Map raw outputs to non-negative values with softplus."""
        return Func.softplus(x)

    def trans(self, gamma1, gamma2, logv, logalpha, logbeta):
        """Constrain the raw head outputs.

        ``gamma1``/``gamma2`` pass through unchanged; ``v`` and ``beta`` become
        ``softplus(.)`` and ``alpha`` becomes ``softplus(.) + 1``.

        :return: ``(gamma1, gamma2, v, alpha, beta)``
        """
        v = self.evidence(logv)
        alpha = self.evidence(logalpha) + 1
        beta = self.evidence(logbeta)
        return gamma1, gamma2, v, alpha, beta

    def forward(self, lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, add_noise=0):
        """Run both views and return their evidential outputs.

        Shape notes below are carried over from the original author's
        docstring; the code itself only requires ``lm_X``/``tg_X`` to be
        ``dim_in`` wide and ``lm_Y`` to be 2 wide.

        :param lm_X: feature of landmarks [..., 30]: 14 attribute + 16 measurement
        :param lm_Y: location of landmarks [..., 2]: longitude + latitude
        :param tg_X: feature of targets [..., 30]
        :param tg_Y: location of targets [..., 2]; only its first dimension
            (the number of targets) is read here
        :param lm_delay: delay from landmark to the common router [..., 1]
        :param tg_delay: delay from target to the common router [..., 1]
        :param add_noise: accepted for interface compatibility; not used in
            this method
        :return: ``(two_gamma_g, v_g, alpha_g, beta_g,
            two_gamma_a, v_a, alpha_a, beta_a)`` -- graph-view outputs first,
            then attribute-view outputs; ``two_gamma_*`` is ``gamma1`` and
            ``gamma2`` concatenated along dim 1.
        """
        N1 = lm_Y.size(0)
        N2 = tg_Y.size(0)
        # Create helper tensors on the inputs' device instead of forcing
        # .cuda(), so the model also runs on CPU or a non-default device.
        device = lm_X.device
        ones = torch.ones(N1 + N2 + 1, device=device)
        lm_feature = torch.cat((lm_X, lm_Y), dim=1)
        # Target locations are unknown, so their two location columns start
        # as zeros.
        tg_feature_0 = torch.cat((tg_X, torch.zeros(N2, 2, device=device)), dim=1)
        # The router node is initialised as the mean of all landmark features.
        router_0 = torch.mean(lm_feature, dim=0, keepdim=True)
        # Row layout used by every matrix below: landmarks, targets, router.
        all_feature_0 = torch.cat((lm_feature, tg_feature_0, router_0), dim=0)

        '''
        star-GNN
        properties:
        1. single directed graph: feature of <landmarks> will never be updated.
        2. the target IP will receive from surrounding landmarks from two ways: 
            (1) attribute similarity-based one-hop propagation;
            (2) delay measurement-based two-hop propagation via the common router;
        '''
        # GNN-step 1
        # Start from the identity so every node keeps a self-loop.
        adj_matrix_0 = torch.diag(ones)

        # star connections (measurement)
        # Flattened so both (N1,) and the documented (N1, 1) delay shapes can
        # be written into the 1-D router row below.
        delay_score = torch.exp(-self.gamma_1 * (self.alpha * lm_delay + self.beta)).reshape(-1)

        rou2tar_score_0 = torch.exp(-self.gamma_2 * (self.alpha * tg_delay + self.beta)).reshape(N2)

        # satellite connections (feature)
        _, attribute_score = self.att_attribute(tg_X, lm_X, lm_feature)
        attribute_score = torch.exp(attribute_score)

        adj_matrix_0[N1:N1 + N2, :N1] = attribute_score   # target <- landmarks
        adj_matrix_0[-1, :N1] = delay_score                # router <- landmarks
        adj_matrix_0[N1:N1 + N2:, -1] = rou2tar_score_0    # target <- router

        # Row-normalise: D^-1 A.
        degree_0 = torch.sum(adj_matrix_0, dim=1)
        degree_reverse_0 = 1.0 / degree_0
        degree_matrix_reverse_0 = torch.diag(degree_reverse_0)

        degree_mul_adj_0 = degree_matrix_reverse_0 @ adj_matrix_0
        step_1_all_feature = self.w_1(degree_mul_adj_0 @ all_feature_0)

        tg_feature_1 = step_1_all_feature[N1:N1 + N2, :]
        router_1 = step_1_all_feature[-1, :].reshape(1, -1)

        # GNN-step 2
        # Only the router -> target edges (plus self-loops) are active here.
        adj_matrix_1 = torch.diag(ones)
        rou2tar_score_1 = torch.exp(-self.gamma_3 * (self.alpha * tg_delay + self.beta)).reshape(N2)
        adj_matrix_1[N1:N1 + N2:, -1] = rou2tar_score_1

        # Landmarks keep their original features; targets and router use the
        # step-1 results.
        all_feature_1 = torch.cat((lm_feature, tg_feature_1, router_1), dim=0)

        degree_1 = torch.sum(adj_matrix_1, dim=1)
        degree_reverse_1 = 1.0 / degree_1
        degree_matrix_reverse_1 = torch.diag(degree_reverse_1)

        degree_mul_adj_1 = degree_matrix_reverse_1 @ adj_matrix_1
        step_2_all_feature = self.w_2(degree_mul_adj_1 @ all_feature_1)
        tg_feature_2 = step_2_all_feature[N1:N1 + N2, :]

        # graph view
        tg_feature_graph_view = torch.cat((
                                      tg_feature_1,
                                      tg_feature_2), dim=-1)
        # attribute view (for shanghai dim=51) 
        tg_feature_attribute_view = tg_X

        '''
        predict
        '''
        output1 = self.out_layer_graph_view(tg_feature_graph_view)
        gamma1_g, gamma2_g, v_g, alpha_g, beta_g = torch.split(output1, 1, dim=-1)
        # attribute
        output2 = self.out_layer_attri_view(tg_feature_attribute_view)
        gamma1_a, gamma2_a, v_a, alpha_a, beta_a = torch.split(output2, 1, dim=-1)

        # transform, let v>0, alpha>1, beta>0
        gamma1_g, gamma2_g, v_g, alpha_g, beta_g = self.trans(gamma1_g, gamma2_g, v_g, alpha_g, beta_g)
        gamma1_a, gamma2_a, v_a, alpha_a, beta_a = self.trans(gamma1_a, gamma2_a, v_a, alpha_a, beta_a)

        two_gamma_g = torch.cat((gamma1_g, gamma2_g), dim=1)
        two_gamma_a = torch.cat((gamma1_a, gamma2_a), dim=1)

        return two_gamma_g, v_g, alpha_g, beta_g, \
               two_gamma_a, v_a, alpha_a, beta_a

这是一个 PyTorch 中神经网络模型的类定义,它继承自 nn.Module 类,表明这个类是一个 PyTorch 模型。

分为几个部分展开描述:

(一)__init__()

    def __init__(self, dim_in):
        super(TrustGeo, self).__init__()
        self.dim_in = dim_in
        self.dim_z = dim_in + 2

        # TrustGeo
        self.att_attribute = SimpleAttention(temperature=self.dim_z ** 0.5,
                                             d_q_in=self.dim_in,
                                             d_k_in=self.dim_in,
                                             d_v_in=self.dim_in + 2,
                                             d_q_out=self.dim_z,
                                             d_k_out=self.dim_z,
                                             d_v_out=self.dim_z)


        # calculate A
      

文章来源:https://blog.csdn.net/sinat_41942180/article/details/134983178
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。