# Train an EntityGraphModule graph classifier on `train_dataloader`, then
# report its classification accuracy on `test_dataloader`.
#
# Relies on names defined elsewhere in the file: `dataset`,
# `EntityGraphModule`, `train_dataloader`, `test_dataloader`, `torch`, and
# `F` (presumably `torch.nn.functional` -- confirm against the imports).

h_feats = 64  # third EntityGraphModule argument; presumably the hidden width
learn_iterations = 50  # number of full passes over train_dataloader
learn_rate = 0.01  # Adam learning rate

model = EntityGraphModule(
    # Column count of the first graph's node / edge "feat" tensors.
    dataset.graphs[0].ndata["feat"].shape[1],
    dataset.graphs[0].edata["feat"].shape[1],
    h_feats,
    # Class count = largest label + 1; assumes labels are 0-based contiguous
    # integers -- NOTE(review): confirm against the dataset.
    dataset.labels.max().item() + 1,
)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

# Training. Set training mode explicitly (it affects layers such as dropout
# or batch-norm, if the model has any) instead of relying on the default.
model.train()
for _ in range(learn_iterations):
    for batched_graph, labels in train_dataloader:
        pred = model(
            batched_graph,
            batched_graph.ndata["feat"].float(),
            batched_graph.edata["feat"].float(),
        )
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluation. Switch to inference mode and disable autograd so no
# computation graph is recorded for the test batches.
model.eval()
num_correct = 0
num_tests = 0
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        pred = model(
            batched_graph,
            batched_graph.ndata["feat"].float(),
            batched_graph.edata["feat"].float(),
        )
        # Predicted class is the arg-max over dimension 1 of the output.
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)
# Raises ZeroDivisionError if test_dataloader yielded no samples.
acc = num_correct / num_tests
print("Test accuracy:", acc)
Preview:
downloadDownload PNG
downloadDownload JPEG
downloadDownload SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter