What Exactly Is GraphSAGE Training? How Does Mini-Batch Training on Graphs Work?
2023-12-13 03:37:58
1. An end-to-end node classification example on a homogeneous graph (the Cora dataset):
import argparse
import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # two-layer GraphSAGE (note: the aggregator actually used here is
        # "gcn", i.e. a mean over each node's neighbors including the node
        # itself, not the plain "mean" aggregator)
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))
        self.dropout = nn.Dropout(0.5)

    def forward(self, graph, x):
        h = self.dropout(x)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h
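# What exactly is being trained here? Only the parameters inside the two
# SAGEConv layers: with the "gcn" aggregator each layer learns a single
# linear projection (weight plus bias term) applied to the mean of a node's
# own and neighbors' features; other aggregator types add a separate
# projection of the node's own features. The graph structure and the input
# features stay fixed, so model.parameters() handed to the optimizer in
# train() below covers exactly these projections and nothing else.
# (Aggregator internals as in DGL's SAGEConv; details may vary by version.)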
def evaluate(g, features, labels, mask, model):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
def train(g, features, labels, masks, model):
    # define train/val samples, loss function and optimizer
    train_mask, val_mask = masks
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    # training loop
    for epoch in range(200):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GraphSAGE")
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed')",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type (float, bfloat16)",
    )
    args = parser.parse_args()
    print("Training with DGL built-in GraphSAGE module")

    # load and preprocess dataset
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"]

    # create GraphSAGE model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = SAGE(in_size, 16, out_size).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, g.ndata["test_mask"], model)
    print("Test accuracy {:.4f}".format(acc))
2. The GraphSAGE implementation: the SAGEConv class:
Let's first look at how DGL implements the GraphSAGE model: the SAGEConv() class, which lives in the library's dgl.nn module (imported above as dglnn).
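Before diving into the source, it helps to pin down what a single SAGEConv layer with the "gcn" aggregator actually computes: take the mean of each node's neighbor features together with its own feature, then apply one learned linear projection. The following pure-PyTorch re-derivation is only an illustration of that formula on a dense toy adjacency matrix, not DGL's actual implementation:

import torch
import torch.nn as nn


class NaiveSAGEConvGCN(nn.Module):
    """Illustrative 'gcn'-style SAGEConv: h_v = W * mean({h_v} U {h_u : u -> v})."""

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.fc_neigh = nn.Linear(in_feats, out_feats)  # the only trained weights

    def forward(self, adj, h):
        # adj: dense (N, N) 0/1 matrix with adj[v, u] = 1 for each edge u -> v
        n = adj.shape[0]
        adj_self = adj + torch.eye(n)            # include the node itself
        deg = adj_self.sum(dim=1, keepdim=True)  # neighbor count + 1
        h_mean = (adj_self @ h) / deg            # mean aggregation
        return self.fc_neigh(h_mean)             # learned projection


# quick shape check on a 3-node path graph 0 - 1 - 2 (edges in both directions)
adj = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
layer = NaiveSAGEConvGCN(in_feats=4, out_feats=2)
print(layer(adj, torch.randn(3, 4)).shape)  # torch.Size([3, 2])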
Source: https://blog.csdn.net/qq_41764621/article/details/134951209