TrustGeo代码理解(二)test.py

2023-12-13 23:32:29

代码链接:https://github.com/ICDM-UESTC/TrustGeo

一、加载检查点(checkpoint)并进行测试

# -*- coding: utf-8 -*-

"""
    load checkpoint and then test
"""

该脚本的目的是加载之前训练过的模型的检查点,并使用测试数据集进行模型的性能评估。

二、导入各种模块和数据库

import torch.nn

from lib.utils import *
import argparse
import numpy as np
import random
from lib.model import *
import copy
import pandas as pd

脚本可能会涉及读取配置、加载模型、加载测试数据、执行测试、记录结果等操作。为了更详细的解释,需要查看 lib 文件夹中的 utilsmodel 模块的具体实现。

1、import torch.nn:导入 PyTorch 的神经网络模块。

2、from lib.utils import *从 lib.utils 模块中导入所有内容,这可能包括一些工具函数或辅助函数,用于该脚本或项目的其他部分。

3、import argparse:导入 argparse 模块用于解析命令行参数

4、import numpy as np:导入 NumPy 库,用于进行科学计算,特别是多维数组的处理。

5、import random导入 random 模块,用于生成伪随机数。

6、from lib.model import *从 lib.model 模块中导入所有内容,这可能包括定义神经网络模型的类等。

7、import copy:导入 copy 模块,用于复制对象,通常用于创建对象的深拷贝

8、import pandas as pd:导入 Pandas 库,用于数据处理和分析,通常用于处理表格型数据。

二、参数初始化(通过命令行参数)

parser = argparse.ArgumentParser()
# parameters of initializing
parser.add_argument('--seed', type=int, default=2022, help='manual seed')
parser.add_argument('--model_name', type=str, default='TrustGeo')
parser.add_argument('--dataset', type=str, default='New_York', choices=["Shanghai", "New_York", "Los_Angeles"],
                    help='which dataset to use')

这部分代码的目的是通过命令行参数设置一些初始化的参数,例如随机数种子、模型名称和数据集名称。这使得在运行脚本时可以通过命令行参数来指定这些参数的值。

1、parser = argparse.ArgumentParser():创建一个 argparse.ArgumentParser 对象,用于解析命令行参数。

2、# parameters of initializing:注释,表示接下来是初始化参数的部分。

3、parser.add_argument('--seed', type=int, default=2022, help='manual seed'):添加一个命令行参数,名称为 '--seed',表示随机数种子,类型为整数,默认值为 2022,help 参数是在命令行中输入 --help 时显示的帮助信息。

4、parser.add_argument('--model_name', type=str, default='TrustGeo'):添加一个命令行参数,名称为 '--model_name',表示模型的名称,类型为字符串,默认值为 'TrustGeo'。

5、parser.add_argument('--dataset', type=str, default='New_York', choices=["Shanghai", "New_York", "Los_Angeles"], help='which dataset to use'):添加一个命令行参数,名称为 '--dataset',表示数据集的名称,类型为字符串,默认值为 'New_York',choices 参数指定了可选的值为 ["Shanghai", "New_York", "Los_Angeles"],用户只能从这三个值中选择。help 参数是在命令行中输入 --help 时显示的帮助信息。

三、训练过程参数设置

# parameters of training
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--lambda1', type=float, default=7e-3)
parser.add_argument('--lr', type=float, default=5e-3)
parser.add_argument('--harved_epoch', type=int, default=5) 
parser.add_argument('--early_stop_epoch', type=int, default=50)
parser.add_argument('--saved_epoch', type=int, default=200) 
parser.add_argument('--load_epoch', type=int, default=5) 

这部分代码的目的是设置一些训练过程中的超参数,例如优化器的动量参数、学习率、权重参数等。这些参数在训练过程中会影响模型的更新和收敛速度。

1、# parameters of training:注释,表示接下来是训练参数的部分。

2、parser.add_argument('--beta1', type=float, default=0.9):添加一个命令行参数,名称为 '--beta1',表示 Adam 优化器的第一个动量(momentum)参数,类型为浮点数,默认值为 0.9。

3、parser.add_argument('--beta2', type=float, default=0.999):添加一个命令行参数,名称为 '--beta2',表示 Adam 优化器的第二个动量参数,类型为浮点数,默认值为 0.999。

4、parser.add_argument('--lambda1', type=float, default=7e-3):添加一个命令行参数,名称为 '--lambda1',表示某个权重参数,类型为浮点数,默认值为 7e-3。

5、parser.add_argument('--lr', type=float, default=5e-3):添加一个命令行参数,名称为 '--lr',表示学习率,类型为浮点数,默认值为 5e-3。

6、parser.add_argument('--harved_epoch', type=int, default=5):添加一个命令行参数,名称为 '--harved_epoch',表示某个 epoch 的值,类型为整数,默认值为 5。

7、parser.add_argument('--early_stop_epoch', type=int, default=50):添加一个命令行参数,名称为 '--early_stop_epoch',表示早停(early stop)的 epoch 数,类型为整数,默认值为 50。

8、parser.add_argument('--saved_epoch', type=int, default=200):??添加一个命令行参数,名称为 '--saved_epoch',表示保存模型的 epoch 数,类型为整数,默认值为 200。

9、parser.add_argument('--load_epoch', type=int, default=5):从指定 epoch 的模型参数加载模型。(比model.py多出来的)

四、模型参数设置

# parameters of model
parser.add_argument('--dim_in', type=int, default=30, choices=[51, 30], help="51 if Shanghai / 30 else")

opt = parser.parse_args()
print("Learning rate: ", opt.lr)
print("Dataset: ", opt.dataset)

这部分代码的目的是解析命令行参数,并打印出学习率和数据集名称。--dim_in 参数用于指定输入维度,可以选择是 51 或者 30。

1、# parameters of model注释,表示接下来是训模型参数的部分。

2、parser.add_argument('--dim_in', type=int, default=30, choices=[51, 30], help="51 if Shanghai / 30 else"):添加一个命令行参数,名称为 '--dim_in',表示输入的维度,类型为整数,默认值为 30。choices 参数指定了可选的值为 [51, 30],用户只能从这两个值中选择。help 参数是在命令行中输入 --help 时显示的帮助信息。

3、opt = parser.parse_args():使用 argparse 解析命令行参数,将结果存储在 opt 变量中

4、print("Learning rate: ", opt.lr):打印学习率,即 opt 对象中的 lr 属性

5、print("Dataset: ", opt.dataset):打印数据集名称,即 opt 对象中的 dataset 属性

五、设置随机种子数

if opt.seed:
    print("Random Seed: ", opt.seed)
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
torch.set_printoptions(threshold=float('inf'))

这一部分的目的是确保在使用随机数的场景中,每次运行程序得到的随机结果是可复现的。通过设置相同的随机数种子,可以使得每次运行得到相同的随机数序列。

1、如果 opt 对象中的 seed 属性存在(不为 0 或 False 等假值),则执行以下操作:

  • 打印随机数种子的信息。
  • 使用 random 模块设置 Python 内建的随机数生成器的种子。
  • 使用 PyTorch 的 torch 模块设置随机数种子。

2、torch.set_printoptions(threshold=float('inf'))设置 PyTorch 的打印选项,将打印的元素数量限制设置为无穷大,即不限制打印的元素数量。这样可以确保在打印张量时,所有元素都会被打印出来,而不会被省略。

六、过滤所有警告信息

warnings.filterwarnings('ignore')

过滤掉所有警告信息,将警告信息忽略。这通常用于在代码中避免显示一些不影响程序执行的警告信息,以保持输出的清晰。在某些情况下,警告信息可能是有用的,但如果明确知道这些警告对程序执行没有影响,可以选择忽略它们。

七、动态选择运行环境

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("device:", device)
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

这部分代码的目的是根据硬件环境动态选择运行模型的设备,并选择相应的 PyTorch 张量类型。如果有可用的 GPU,就使用 GPU 运行模型和 GPU 张量类型;否则,使用 CPU 运行模型和 CPU 张量类型。

1、device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'):创建一个 PyTorch 设备对象,表示运行模型的设备。如果 CUDA 可用(即有可用的 GPU),则使用 'cuda:0' 表示第一个 GPU,否则使用 'cpu' 表示 CPU。

2、print("device:", device):打印设备的信息,即使用的是 GPU 还是 CPU。

3、cuda = True if torch.cuda.is_available() else False:根据 CUDA 是否可用设置一个布尔值,表示是否使用 GPU。如果 CUDA 可用,则 cuda 为 True,否则为 False。

4、Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor:根据上一步得到的 cuda 布尔值选择使用 GPU 还是 CPU 上的 PyTorch 张量类型。如果 cuda 为 True,则 Tensor 被设置为 torch.cuda.FloatTensor,表示在 GPU 上的浮点数张量类型,否则设置为 torch.FloatTensor,表示在 CPU 上的浮点数张量类型。

八、加载数据(训练和测试)

'''load data'''
train_data = np.load("./datasets/{}/Clustering_s1234_lm70_train.npz".format(opt.dataset),
                     allow_pickle=True)
test_data = np.load("./datasets/{}/Clustering_s1234_lm70_test.npz".format(opt.dataset),
                    allow_pickle=True)
train_data, test_data = train_data["data"], test_data["data"]
print("data loaded.")

这部分代码的目的是加载训练集和测试集的数据,数据文件的路径根据 opt.dataset 的值确定(见四、模型参数设置)。加载后,训练集和测试集的数据存储在 train_data 和 test_data 变量中。

1、'''load data''':这是一个注释,用于指示下面的代码块是用于加载数据的部分。

2、

train_data = np.load("./datasets/{}/Clustering_s1234_lm70_train.npz".format(opt.dataset), allow_pickle=True):使用 NumPy 的 load 函数加载训练数据。数据文件的路径由字符串格式化方法确定,其中 {} 部分会被 opt.dataset 替代,即数据集的名称。文件名的其余部分指定了数据集的具体文件名和路径。allow_pickle=True 表示允许加载包含 Python 对象的文件

3、test_data = np.load("./datasets/{}/Clustering_s1234_lm70_test.npz".format(opt.dataset),

allow_pickle=True):使用相同的方式加载测试数据,文件名中指定了测试集的文件名和路径。

4、train_data, test_data = train_data["data"], test_data["data"]:从加载的数据中提取具体的数据部分。这里假设加载的数据文件中包含一个名为 "data" 的键,其对应的值是实际的数据。train_data 和 test_data 分别表示训练集和测试集的数据。

5、print("data loaded."):打印提示信息,表示数据加载完成。

没有模型初始化、标准和优化器初始化

九、__main__函数

if __name__ == '__main__':
    train_data, test_data = get_data_generator(opt, train_data, test_data, normal=2)

    losses = [np.inf]

    checkpoint = torch.load(f"asset/model/{opt.dataset}_{opt.load_epoch}.pth")
    print(f"Load model asset/model/{opt.dataset}_{opt.load_epoch}.pth")
    model = eval("TrustGeo")(opt.dim_in)
    model.load_state_dict(checkpoint['model_state_dict'])
    if cuda:
        model.cuda() 

    # test
    total_mse, total_mae, test_num = 0, 0, 0
    dislist = []

    model.eval()
    distance_all = []  
    macs_list = []
    params_list = []

    with torch.no_grad():

        for i in range(len(test_data)):
            lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, y_max, y_min = test_data[i]["lm_X"], test_data[i]["lm_Y"], \
                                                                           test_data[i][
                                                                               "tg_X"], test_data[i]["tg_Y"], \
                                                                           test_data[i][
                                                                               "lm_delay"], test_data[i]["tg_delay"], \
                                                                           test_data[i]["y_max"], test_data[i]["y_min"]

            y_pred_g, v_g, alpha_g, beta_g, y_pred_a, v_a, alpha_a, beta_a = model(Tensor(lm_X), Tensor(lm_Y), Tensor(tg_X),
                                                                                                                    Tensor(tg_Y), Tensor(lm_delay),Tensor(tg_delay))
            
            # fuse multi views
            y_pred_f, v_f, alpha_f, beta_f = fuse_nig(y_pred_g, v_g, alpha_g, beta_g, y_pred_a, v_a, alpha_a, beta_a)
               
            distance = dis_loss(Tensor(tg_Y), y_pred_f, y_max, y_min)
            for i in range(len(distance.cpu().detach().numpy())):
                dislist.append(distance.cpu().detach().numpy()[i])
                distance_all.append(distance.cpu().detach().numpy()[i])
                
            test_num += len(tg_Y)
            total_mse += (distance * distance).sum()
            total_mae += distance.sum()
            
        total_mse = total_mse / test_num
        total_mae = total_mae / test_num
    
        print("test: mse: {:.3f}  mae: {:.3f}".format(total_mse, total_mae))
        dislist_sorted = sorted(dislist)
        print('test median: {:.3f}'.format(dislist_sorted[int(len(dislist_sorted) / 2)]))

分为几个部分展开描述:

没有将配置文件写入日志文件、模型训练

(一)该脚本是否直接运行

if __name__ == '__main__':
    train_data, test_data = get_data_generator(opt, train_data, test_data, normal=2)

这是 Python 中常见的用法,表示如果该脚本是被直接运行而不是被导入为模块,那么以下的代码块将被执行。这通常用于将脚本既作为可执行程序又作为一个模块导入的情况。调用一个函数 get_data_generator。这里传递了一些参数,包括 opt、train_data、test_data 和 normal。get_data_generator的实现在utils.py文件

(二)初始化一些变量

losses = [np.inf]:初始化一些变量,包括保存训练过程中的损失值

(三)加载预训练模型的检查点文件

    checkpoint = torch.load(f"asset/model/{opt.dataset}_{opt.load_epoch}.pth")
    print(f"Load model asset/model/{opt.dataset}_{opt.load_epoch}.pth")
    model = eval("TrustGeo")(opt.dim_in)
    model.load_state_dict(checkpoint['model_state_dict'])
    if cuda:
        model.cuda() 

这段代码的目的是加载预训练模型的检查点文件,并将其参数应用于新创建的模型实例。这样,就可以在之后的代码中使用这个加载的模型进行推理或者继续训练。

1、checkpoint = torch.load(f"asset/model/{opt.dataset}_{opt.load_epoch}.pth"): 这一行代码使用 torch.load 函数加载保存在文件系统中的 PyTorch 模型检查点。opt.dataset 表示数据集名称,opt.load_epoch 表示加载的检查点的训练时期。加载后,checkpoint 将包含保存在检查点文件中的各种信息,例如模型的状态字典、优化器状态等。

2、print(f"Load model asset/model/{opt.dataset}_{opt.load_epoch}.pth"): 这一行代码简单地打印加载的模型文件的路径,以提供用户一些反馈。

3、model = eval("TrustGeo")(opt.dim_in): 这一行代码通过 eval 函数创建了一个新的模型实例。"TrustGeo" 是模型的类名,通过字符串动态地创建模型的实例。opt.dim_in 是模型的输入维度。这是因为在训练和保存模型时,可能没有保存模型的实例,而只保存了模型的参数,因此需要重新创建模型实例。

4、model.load_state_dict(checkpoint['model_state_dict']): 这一行代码将预训练模型的状态字典加载到新创建的模型实例中。checkpoint['model_state_dict'] 包含了预训练模型的参数。

5、这一行代码将模型移动到 GPU 上,如果 GPU 可用的话。这确保模型在进行推理或训练时利用 GPU 资源。

?(四)模型测试

这段代码主要用于在测试集上评估模型的性能,并且记录了一些性能指标和其他信息。

1、# test:这是一个注释,用于指示下面的代码块是用于测试模型的部分。

2、初始化一些变量(与训练时初始化的变量不一样)

    total_mse, total_mae, test_num = 0, 0, 0
    dislist = []

    model.eval()
    distance_all = []
    macs_list = []
    params_list = []

这段代码的目的是在测试阶段对每个测试样本进行模型预测,计算损失值,累加总体均方误差和平均绝对误差,并记录每个测试样本的具体损失值以及一些与模型复杂度相关的信息。这些信息可以用于评估模型在测试集上的性能,并分析模型的计算资源使用情况。

total_mse, total_mae, test_num = 0, 0, 0?初始化三个变量,分别用于累加均方误差(total_mse)、平均绝对误差(total_mae)以及测试样本的数量(test_num)。

dislist = []: 初始化一个空列表 dislist,用于存储每个测试样本的损失值。

model.eval(): 将模型设置为评估模式,这是因为在测试阶段不需要进行梯度计算,而且可能有一些与训练不同的行为,例如对于某些层的批量标准化。

distance_all = []: 初始化一个空列表 distance_all,似乎是用于存储所有测试样本的损失值。

macs_list = []params_list = []: 初始化两个空列表,用于存储每个测试样本对应的模型的浮点运算量(MACs)和参数数量。这些信息通常用于模型的计算资源分析。(比model.py多出来的)

3、with torch.no_grad():进入无梯度计算的上下文,即下面的计算不会影响梯度。

(1)对于每个测试样本进行以下操作

        for i in range(len(test_data)):
            lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, y_max, y_min = test_data[i]["lm_X"], test_data[i]["lm_Y"], \
                                                                           test_data[i][
                                                                               "tg_X"], test_data[i]["tg_Y"], \
                                                                           test_data[i][
                                                                               "lm_delay"], test_data[i]["tg_delay"], \
                                                                           test_data[i]["y_max"], test_data[i]["y_min"]

            y_pred_g, v_g, alpha_g, beta_g, y_pred_a, v_a, alpha_a, beta_a = model(Tensor(lm_X), Tensor(lm_Y), Tensor(tg_X),
                                                                                                                    Tensor(tg_Y), Tensor(lm_delay),Tensor(tg_delay))
            
            # fuse multi views
            y_pred_f, v_f, alpha_f, beta_f = fuse_nig(y_pred_g, v_g, alpha_g, beta_g, y_pred_a, v_a, alpha_a, beta_a)
               
            distance = dis_loss(Tensor(tg_Y), y_pred_f, y_max, y_min)
            for i in range(len(distance.cpu().detach().numpy())):
                dislist.append(distance.cpu().detach().numpy()[i])
                distance_all.append(distance.cpu().detach().numpy()[i])
                
            test_num += len(tg_Y)
            total_mse += (distance * distance).sum()
            total_mae += distance.sum()

①从测试数据中获取需要的输入和标签

lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, y_max, y_min = test_data[i]["lm_X"], test_data[i]["lm_Y"], \
                                                                           test_data[i][
                                                                               "tg_X"], test_data[i]["tg_Y"], \
                                                                           test_data[i][
                                                                               "lm_delay"], test_data[i]["tg_delay"], \
                                                                           test_data[i]["y_max"], test_data[i]["y_min"]

②使用模型进行前向传播,得到模型的输出。?

y_pred_g, v_g, alpha_g, beta_g, y_pred_a, v_a, alpha_a, beta_a = model(Tensor(lm_X), Tensor(lm_Y), Tensor(tg_X),Tensor(tg_Y), Tensor(lm_delay),Tensor(tg_delay))
                                                                                

model的实现在model.py文件?(class TrustGeo)

③融合多视图输出,得到最终的输出。与训练时一样

# fuse multi views
y_pred_f, v_f, alpha_f, beta_f = fuse_nig(y_pred_g, v_g, alpha_g, beta_g, y_pred_a, v_a, alpha_a, beta_a)

fuse_nig的实现在utils.py文件?

④计算距离损失,并将每个样本的距离值记录到列表中。

distance = dis_loss(Tensor(tg_Y), y_pred_f, y_max, y_min)
for i in range(len(distance.cpu().detach().numpy())):
    dislist.append(distance.cpu().detach().numpy()[i])
    distance_all.append(distance.cpu().detach().numpy()[i])

dis_loss的实现在utils.py文件

distance 是一个 PyTorch 张量(Tensor),通过 cpu().detach().numpy() 转换为 NumPy 数组,以便后续处理。

dislist.append(distance.cpu().detach().numpy()[i]): 将 distance 中的每个元素添加到列表 dislist 中。这个列表用于收集每个样本的损失值。

distance_all.append(distance.cpu().detach().numpy()[i]): 将 distance 中的每个元素添加到另一个列表 distance_all 中。这个列表用于在整个测试集上收集损失值。

⑤更新总MSE、总MAE和总测试样本数(test_num)。total_mse和total_mae计算方式与训练时不一样

test_num += len(tg_Y)
total_mse += (distance * distance).sum()
total_mae += distance.sum()

?test_num += len(tg_Y): 这一行代码用于累加测试样本的数量,len(tg_Y) 表示当前批次的测试样本数量,test_num 是一个用于存储总测试样本数量的变量。

total_mse += (distance * distance).sum(): 这一行代码计算并累加每个测试样本的均方误差。distance 是前面计算的模型预测和实际地理位置之间的损失。通过 (distance * distance).sum() 计算了每个样本的平方损失,然后将这些平方损失相加,得到总的均方误差

total_mae += distance.sum(): 这一行代码计算并累加每个测试样本的平均绝对误差。distance 是前面计算的模型预测和实际地理位置之间的损失。通过 distance.sum() 计算了每个样本的绝对损失,然后将这些绝对损失相加,得到总的平均绝对误差。

这两个累加操作最终将整个测试集上的均方误差 (total_mse) 和平均绝对误差 (total_mae) 计算出来。这些指标用于评估模型在测试集上的性能,其中均方误差衡量了预测值与真实值之间的平方差,而平均绝对误差衡量了预测值与真实值之间的绝对差。

(2)计算平均MSE损失和平均 MAE。(总样本数在这里会使用到)

total_mse = total_mse / test_num
total_mae = total_mae / test_num

(3)?print("test: mse: {:.4f} mae: {:.4f}".format(total_mse, total_mae))打印平均MSE损失和平均 MAE。

(4)dislist_sorted = sorted(dislist):对距离列表进行排序?训练没有

(5)print('test median:', dislist_sorted[int(len(dislist_sorted) / 2)]):打印测试样本距离的中值??

没有保存检查点

没有对模型性能进行监控和控制

没有将当前 epoch 的 MAE 添加到损失列表中

没有学习率减半?

没有早停(early stopping)机制

文章来源:https://blog.csdn.net/sinat_41942180/article/details/134982175
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。