import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision.models import ResNet50_Weights
class ResNet50:
    """Thin wrapper around torchvision's ImageNet-pretrained ResNet-50.

    Loads ``ResNet50_Weights.IMAGENET1K_V1``, optionally replaces the final
    fully-connected layer (``model.fc``) for a different number of classes,
    and moves the model to the selected device (``npu:0`` by default).
    """

    def __init__(self, num_classes=1000, device=None):
        """Build the pretrained model and place it on a device.

        :param num_classes: Number of output classes. When it differs from
            1000, ``model.fc`` is replaced by a freshly initialised
            ``nn.Linear`` with ``num_classes`` outputs.
        :param device: Any spec accepted by ``torch.device`` (e.g. ``"npu:0"``,
            ``"cpu"``). When ``None``, ``npu:0`` is used.
        :raises RuntimeError: If ``device`` is ``None`` and no NPU is available.
        """
        if device is None:
            # NOTE(review): ``torch.npu`` is only registered after ``torch_npu``
            # has been imported; presumably that happens elsewhere — confirm.
            if hasattr(torch, 'npu') and torch.npu.is_available():
                self.device = torch.device("npu:0")
            else:
                raise RuntimeError(
                    "Current environment does not support NPU data collection. "
                    "Please check the versions of torch and torch_npu. "
                    "Recommended: torch >= 2.7.1, torch_npu >= 2.7.1, Python >= 3.7.5"
                )
        else:
            self.device = torch.device(device)
        # Only NPU devices need to be made current; calling torch.npu.set_device
        # for e.g. "cpu" or "cuda:0" would fail (torch.npu may not even exist).
        if self.device.type == "npu":
            torch.npu.set_device(self.device)
        print(f"[INFO] Using device: {self.device}")
        self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        if num_classes != 1000:
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.model = self.model.to(self.device)

    def train(self, data_loader, epochs=1, lr=1e-4, freeze_backbone=False):
        """Train the model with Adam and cross-entropy loss.

        :param data_loader: Iterable of ``(images, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``. It does not need to support
            ``len()``; batches are counted while iterating.
        :param epochs: Number of passes over ``data_loader``.
        :param lr: Adam learning rate.
        :param freeze_backbone: If True, only ``model.fc`` is trained; if
            False, every parameter is trained. The flag is applied on every
            call, so a later call with False un-freezes a backbone frozen by
            an earlier call.
        :return: List with the average training loss of each epoch
            (``nan`` for an epoch in which ``data_loader`` yielded no batches).
        """
        # Set requires_grad explicitly in both directions so a previous call
        # with freeze_backbone=True does not silently keep the backbone frozen.
        for param in self.model.parameters():
            param.requires_grad = not freeze_backbone
        for param in self.model.fc.parameters():
            param.requires_grad = True
        params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = optim.Adam(params_to_optimize, lr=lr)
        criterion = nn.CrossEntropyLoss().to(self.device)
        self.model.train()
        epoch_losses = []
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1
            # Avoid ZeroDivisionError when the loader produced no batches.
            avg_loss = total_loss / num_batches if num_batches else float("nan")
            epoch_losses.append(avg_loss)
            print(f"[Epoch {epoch + 1}/{epochs}] Average Loss: {avg_loss:.4f}")
        return epoch_losses