import os
import apex
import torch
import torch_npu
from apex import amp
from torch import nn
from ascend_utils.common.utils import count_parameters
from modelslim.pytorch import sparse
# Pick the NPU to run on: the index is read from the DEVICE_ID environment
# variable and falls back to 0 when the variable is not set.
_npu_index = os.environ.get('DEVICE_ID', 0)
device = torch.device("npu:{}".format(_npu_index))
# Make the selected NPU the current device for subsequent torch.npu calls.
torch.npu.set_device(device)
def _expand_reduce_stage():
    """Build one 1x1 Conv(32->64) -> BatchNorm(64) -> 1x1 Conv(64->32) stage."""
    return nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(64),
        nn.Conv2d(64, 32, kernel_size=1, stride=1, bias=False),
    )


# Toy network: a 1x1 stem conv, four identical expand/reduce stages, global
# average pooling and a bias-free 10-way linear head. Modules are constructed
# in the same left-to-right order as before, so child indices (0..7) and the
# order of parameter initialisation are unchanged.
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=1, stride=1, bias=False),
    *[_expand_reduce_stage() for _ in range(4)],
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(32, 10, bias=False),
).to(device)
# Fused SGD optimizer from apex, operating on every model parameter.
optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), lr=0.1)

# Schedule: optimisation steps per epoch, and epochs spent in each stage.
steps_per_epoch = 10
epochs_each_stage = [2, 3, 1]

# Parameter count captured before amp / sparse wrap the model.
# NOTE(review): "oring" looks like a typo for "origin"; the name is kept
# as-is in case it is read elsewhere.
oring_model_params = count_parameters(model)

# Enable apex mixed precision at opt level "O2" without gradient combining.
model, optimizer = amp.initialize(
    model,
    optimizer,
    opt_level="O2",
    combine_grad=False,
)

# Hand model and optimizer to modelslim's width-sparsification wrapper together
# with the stage schedule defined above.
model = sparse.sparse_model_width(
    model,
    optimizer,
    steps_per_epoch=steps_per_epoch,
    epochs_each_stage=epochs_each_stage,
)
# Total number of optimisation steps across every stage of the schedule.
total_steps = steps_per_epoch * sum(epochs_each_stage)

# The input is a constant all-ones batch, so build it and move it to the NPU
# once instead of re-allocating and copying it on every iteration.
dummy_input = torch.ones([1, 3, 32, 32]).npu()

for _ in range(total_steps):
    optimizer.zero_grad()
    output = model(dummy_input)
    # Mean of the outputs serves as a scalar stand-in loss for this demo run.
    loss = torch.mean(output)
    # Backpropagate through the loss that amp scales for mixed precision.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # Step after leaving the scale_loss context.
    optimizer.step()