
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import YourGraphDataset  # placeholder: replace with your dataset class
from torch_geometric.nn import GCNConv
from torch.utils.data import random_split

# Define the GCN model
class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
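
# Note (assumption): the GCN above returns one output per node. If the task is
# graph-level classification -- which the batched DataLoader and the use of
# data.num_graphs during evaluation suggest -- a pooling/readout step is usually
# added. A minimal sketch of such a variant (not used further in this script),
# built on PyTorch Geometric's global_mean_pool:
from torch_geometric.nn import global_mean_pool

class GraphGCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # average node embeddings per graph
        return self.lin(x)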

# Prepare the data
dataset = YourGraphDataset(root='your_dataset_root')
num_samples = len(dataset)
train_size = int(0.7 * num_samples)
val_size = int(0.15 * num_samples)
# Give the test set whatever remains so the three sizes sum to num_samples
test_size = num_samples - train_size - val_size

# Split the data into train, test, and validation sets
train_data, rest_data = random_split(dataset, [train_size, num_samples - train_size])
test_data, val_data = random_split(rest_data, [test_size, val_size])

# Create a DataLoader for each split
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
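
# In PyTorch Geometric, the DataLoader collates the graphs of a mini-batch into
# one large disconnected graph and adds a `batch` vector mapping every node to
# its graph. A quick, optional sanity check (assumes the loaders above):
sample_batch = next(iter(train_loader))
print(sample_batch)             # e.g. DataBatch(x=[...], edge_index=[2, ...], y=[...], batch=[...])
print(sample_batch.num_graphs)  # number of graphs collated into this batch (at most 64)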

# Initialize the GCN model
input_dim = dataset.num_node_features
output_dim = dataset.num_classes if hasattr(dataset, 'num_classes') else 1
hidden_dim = 64
model = GCN(input_dim, hidden_dim, output_dim)

# Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = (torch.nn.CrossEntropyLoss() if hasattr(dataset, 'num_classes')
             else torch.nn.BCEWithLogitsLoss())
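
# Note (assumption): BCEWithLogitsLoss expects float targets with the same shape
# as the model output, so the binary branch typically needs a small adjustment,
# e.g. criterion(out, data.y.float().view_as(out)). A standalone illustration:
_logits = torch.randn(8, 1)             # raw scores for 8 samples
_targets = torch.randint(0, 2, (8,))    # integer 0/1 labels
_loss = torch.nn.BCEWithLogitsLoss()(_logits, _targets.float().view(-1, 1))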

def train_epoch(model, loader, optimizer, criterion):
    model.train()
    for data in loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

def evaluate_model(model, loader, criterion):
    model.eval()
    total_loss = 0
    total_correct = 0
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index)
            if hasattr(dataset, 'num_classes'):
                # Multi-class: predicted class is the argmax over the logits
                pred = out.argmax(dim=1)
                total_correct += pred.eq(data.y).sum().item()
            else:
                # Binary: a positive logit means predicted label 1
                pred = (out > 0).float()
                total_correct += (pred == data.y).sum().item()
            total_loss += criterion(out, data.y).item() * data.num_graphs
    return total_loss / len(loader.dataset), total_correct / len(loader.dataset)

num_epochs = 10
for epoch in range(num_epochs):
    train_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate_model(model, val_loader, criterion)
    print(f'Epoch {epoch + 1}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')


# Evaluate the model on the test set
test_loss, test_acc = evaluate_model(model, test_loader, criterion)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')
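
# Optional (assumption): persist the trained weights with standard PyTorch
# serialization; the file name below is only an illustrative placeholder.
torch.save(model.state_dict(), 'gcn_model.pt')
# To reuse later, rebuild the architecture and load the state dict:
# model = GCN(input_dim, hidden_dim, output_dim)
# model.load_state_dict(torch.load('gcn_model.pt'))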
