参考文章
- https://zhuanlan.zhihu.com/p/378474516
代码示例
def train(args, *, device=None, steps_per_epoch=300):
    """Train a GAT node classifier, showing a tqdm progress bar per epoch.

    Each of the ``args.epochs`` epochs runs ``steps_per_epoch`` full-graph
    optimisation steps on the nodes selected by ``data.train_mask``.

    NOTE(review): this is an excerpt -- ``pre_process``, ``GAT``, ``nn``,
    ``torch`` and ``tqdm`` are defined/imported elsewhere in the article.

    Args:
        args: namespace exposing an integer ``epochs`` attribute.
        device: torch device to train on; when ``None``, CUDA is used if
            available, otherwise CPU. (The original snippet referenced an
            undefined global ``device``.)
        steps_per_epoch: optimisation steps per epoch (default 300, as before).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    datasets = pre_process()
    data = datasets[0].to(device)
    # [200, 100] is forwarded to GAT unchanged; its meaning depends on GAT's
    # signature, which is not shown here (presumably hidden layer widths).
    model = GAT(datasets.num_features, datasets.num_classes, [200, 100]).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    # Nothing below switches the model to eval mode, so training mode is set once.
    model.train()
    for epoch in range(args.epochs):
        # 1-based label so the last epoch reads [N/N], matching the example below.
        loop = tqdm(range(steps_per_epoch), total=steps_per_epoch,
                    desc=f'Epoch [{epoch + 1}/{args.epochs}]')
        for _ in loop:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())
案例1
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from dataset import get_data
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm
class GraphCNN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Pipeline: GCNConv -> ReLU -> GCNConv -> log_softmax over dim 1.
    """

    def __init__(self, in_channels, out_channels, hidden_channels):
        """Create the two graph-convolution layers.

        Args:
            in_channels: size of each input node feature vector.
            out_channels: number of output classes.
            hidden_channels: width of the hidden layer.
        """
        super().__init__()
        self.conv1 = pyg_nn.GCNConv(in_channels, hidden_channels)
        self.conv2 = pyg_nn.GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return per-node class log-probabilities.

        Args:
            data: graph object exposing ``x`` (node features) and
                ``edge_index`` (connectivity) -- presumably a PyG ``Data``.

        Returns:
            ``log_softmax`` of the second layer's output along dim 1. Because
            the output is already log-probabilities, pair it with
            ``nn.NLLLoss`` (``nn.CrossEntropyLoss`` would re-apply log_softmax).
        """
        edge_index = data.edge_index
        hidden = F.relu(self.conv1(x=data.x, edge_index=edge_index))
        logits = self.conv2(x=hidden, edge_index=edge_index)
        return F.log_softmax(logits, dim=1)
def _train_step(model, data, optimizer, loss_fn):
    """Run one full-graph optimisation step on the training nodes; return the loss."""
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


def _test_accuracy(model, data):
    """Return the model's accuracy on the nodes selected by ``data.test_mask``."""
    model.eval()
    # Inference only: avoid building an autograd graph we never backprop through.
    with torch.no_grad():
        pred = model(data).argmax(dim=1)
    return metrics.accuracy_score(
        y_true=data.y[data.test_mask].cpu(),
        y_pred=pred[data.test_mask].cpu(),
    )


def _plot_history(epochs, losses, accs):
    """Plot loss and test accuracy side by side, each on its own labelled axes."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))
    loss_ax.plot(epochs, losses)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    acc_ax.plot(epochs, accs)
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    fig.tight_layout()
    plt.show()


def train(num_runs=10, num_epochs=100):
    """Train ``GraphCNN`` on the dataset from ``get_data`` and plot its progress.

    Performs ``num_runs`` rounds of ``num_epochs`` full-graph steps. After every
    step it records the training loss and the test-set accuracy, and after each
    round it shows a loss / accuracy figure for that round.

    NOTE(review): the model and optimizer are built once and shared by all
    rounds, so later rounds continue training instead of starting fresh. Move
    their construction inside the round loop if independent runs are intended.

    Args:
        num_runs: number of rounds (default 10, as before).
        num_epochs: optimisation steps per round (default 100, as before).
    """
    cora_dataset = get_data()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GraphCNN(in_channels=cora_dataset.num_features, out_channels=cora_dataset.num_classes, hidden_channels=16)
    model = model.to(device)
    data = cora_dataset[0].to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # GraphCNN.forward already returns log_softmax output, so NLLLoss is the
    # matching criterion; CrossEntropyLoss would apply log_softmax a second time.
    loss_fn = nn.NLLLoss()
    for item in range(num_runs):
        epochs, losses, accs = [], [], []
        loop = tqdm(range(num_epochs), total=num_epochs, desc='train')
        for epoch in loop:
            loss = _train_step(model, data, optimizer, loss_fn)
            acc = _test_accuracy(model, data)
            epochs.append(epoch + 1)
            losses.append(loss)
            accs.append(acc)
            loop.set_description(f'Item [{item + 1}/{num_runs}] Epoch [{epoch + 1}/{num_epochs}]')
            loop.set_postfix(loss=loss, acc=acc)
        _plot_history(epochs, losses, accs)
if __name__ == '__main__':
    # Script entry point: run the GCN training demo only when executed directly.
    train()
效果如下：

（原文此处的效果图在导出时丢失。）