Professional Documents
Culture Documents
import torch.nn.functional as F
class GCN(torch.nn.Module):
super(GCN, self).__init__()
x = F.relu(self.conv1(x, edge_index))
x = self.conv2(x, edge_index)
return x
# Chuẩn bị dữ liệu
dataset = YourGraphDataset(root='your_dataset_root')
num_samples = len(dataset)
input_dim = dataset.num_node_features
hidden_dim = 64
model.train()
optimizer.zero_grad()
optimizer.step()
model.eval()
total_loss = 0
total_correct = 0
with torch.no_grad():
if hasattr(dataset, 'num_classes'):
pred = out.argmax(dim=1)
total_correct += pred.eq(data.y).sum().item()
else:
num_epochs = 10