import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
import os
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import ctypes
# Load the xgpu proxy shared library; train() calls its minibatch_begin_hook()
# at the start of every minibatch.
# NOTE: the path is relative, so it is resolved against the current working
# directory of the launching process.
xgpu_proxy = ctypes.CDLL("../target/release/libxgpu_proxy.so")

# cuDNN is switched off for the whole run.
# NOTE(review): presumably required by the proxy library -- confirm.
torch.backends.cudnn.enabled = False
print("torch.backends.cudnn.enabled = False")

# One process per GPU; LOCAL_RANK is read from the environment (set by the
# launcher) and selects which CUDA device this process drives.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
print(f"Using device: {device}")

# Convert images to tensors and normalise with fixed mean/std constants.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Every rank asks for download=True on the same ./data directory. To avoid
# several processes downloading/extracting the same files concurrently, local
# rank 0 goes first while the other ranks wait at a barrier; once the files
# exist, the remaining ranks find them on disk and skip the download.
# Each rank calls dist.barrier() exactly once, so the collective is matched.
if local_rank != 0:
    dist.barrier()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
if local_rank == 0:
    dist.barrier()

# DistributedSampler gives each rank its own shard of the dataset; the test
# shard order is kept fixed (shuffle=False).
train_sampler = DistributedSampler(train_dataset)
test_sampler = DistributedSampler(test_dataset, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=64, sampler=test_sampler)
class SimpleCNN(nn.Module):
    """Small convolutional classifier producing 10 logits per sample.

    Two 3x3 convolutions (padding 1), each followed by ReLU and 2x2 max
    pooling, feed a two-layer fully connected head. The flattened width
    ``32 * 7 * 7`` is sized for single-channel 28x28 inputs: each pooling step
    halves the spatial size (28 -> 14 -> 7) and the second convolution emits
    32 channels.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor; one pooling layer is shared by
        # both stages.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        # Fully connected classification head.
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        # Stage 1: conv -> ReLU -> pool.
        features = self.conv1(x)
        features = self.relu1(features)
        features = self.pool(features)

        # Stage 2: conv -> ReLU -> pool.
        features = self.conv2(features)
        features = self.relu2(features)
        features = self.pool(features)

        # Collapse channel and spatial dimensions into one feature vector per
        # sample before the linear layers.
        flat = features.view(-1, 32 * 7 * 7)
        hidden = self.relu3(self.fc1(flat))
        return self.fc2(hidden)
# Build the network on this process's device and wrap it in
# DistributedDataParallel, pinned to the GPU selected by local_rank.
model = DDP(
    SimpleCNN().to(device),
    device_ids=[local_rank],
    output_device=local_rank,
)

# Adam over the wrapped model's parameters, with cross-entropy as the
# training loss.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
def train(model, train_loader, criterion, optimizer, epochs=5):
    """Optimise ``model`` for ``epochs`` passes over ``train_loader``.

    Args:
        model: Network to train; switched to training mode once up front.
        train_loader: Iterable of ``(data, target)`` batches. Its ``sampler``
            must provide ``set_epoch`` and its ``dataset`` must support
            ``len``.
        criterion: Callable mapping ``(output, target)`` to a scalar loss.
        optimizer: Optimizer whose ``zero_grad``/``step`` run once per batch.
        epochs: Number of passes over ``train_loader``.

    Rank 0 prints the current loss every 50 batches and a summary (average
    loss and wall-clock time) after each epoch. The reported losses are this
    process's own values; nothing here averages them across ranks.

    Relies on the module-level ``device`` and ``xgpu_proxy``.
    """
    model.train()
    for epoch in range(epochs):
        # Tell the sampler which epoch is starting before iterating.
        train_loader.sampler.set_epoch(epoch)
        start_time = time.time()
        running_loss = 0.0

        for batch_idx, (data, target) in enumerate(train_loader):
            # Notify the xgpu proxy that a new minibatch begins.
            xgpu_proxy.minibatch_begin_hook()
            data = data.to(device)
            target = target.to(device)

            # Standard step: clear gradients, forward, backward, update.
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Only rank 0 logs, and only on every 50th batch.
            if dist.get_rank() != 0 or batch_idx % 50 != 0:
                continue
            print(f'\n\nEpoch {epoch+1}/{epochs} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')

        epoch_time = time.time() - start_time
        if dist.get_rank() == 0:
            print(f'\n\n\n\nEpoch {epoch+1} completed, Average Loss: {running_loss/len(train_loader):.6f}, Time: {epoch_time:.2f}s')
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the global accuracy.

    Each process only iterates over the batches its own loader yields, so the
    per-rank correct/total counts are summed across all processes with
    ``dist.all_reduce`` before rank 0 prints the result. Because of that
    collective, every rank in the process group must call this function.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches.

    Relies on the module-level ``device``. Returns ``None``.

    NOTE: if the loader uses a ``DistributedSampler`` that pads shards to
    equal length, the padded duplicates are included in the totals.
    """
    model.eval()
    # [correct, total] kept on-device so they can be all-reduced directly and
    # no host synchronisation is needed per batch.
    counts = torch.zeros(2, dtype=torch.long, device=device)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            counts[0] += (predicted == target).sum()
            counts[1] += target.size(0)

    # Combine the per-rank tallies so the reported figure covers the whole
    # test set rather than only rank 0's shard.
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    correct, total = counts.tolist()

    if dist.get_rank() == 0:
        if total:
            print(f'Test Accuracy: {100 * correct / total:.2f}%')
        else:
            # Avoid a ZeroDivisionError when no samples were evaluated.
            print('Test Accuracy: n/a (no test samples)')
# Script body: every rank trains for one epoch and evaluates; rank 0
# additionally reports the CUDA version and writes the checkpoint.
_is_rank_zero = dist.get_rank() == 0

if _is_rank_zero:
    print("cuda version:", torch.version.cuda)

train(model, train_loader, criterion, optimizer, epochs=1)
test(model, test_loader)

if _is_rank_zero:
    # `model` is the DDP wrapper here, so the saved keys carry its
    # "module." prefix.
    torch.save(model.state_dict(), 'mnist_cnn.pth')
    print("Model saved as mnist_cnn.pth")

# Tear down the process group created at start-up.
dist.destroy_process_group()