'''
demo.py: ST-GCN action-recognition inference demo on a sample skeleton input.
'''
import torch
import numpy as np
from apex import amp
import torch.distributed as dist
import torch.optim as optim
from collections import OrderedDict
def load_weights(model, weights_path):
print('Load weights from {}.'.format(weights_path))
weights = torch.load(weights_path)
weights = OrderedDict([[k.split('module.')[-1],
v.cpu()] for k, v in weights.items()])
try:
model.load_state_dict(weights)
except (KeyError, RuntimeError):
state = model.state_dict()
diff = list(set(state.keys()).difference(set(weights.keys())))
for d in diff:
print('Can not find weights [{}].'.format(d))
state.update(weights)
model.load_state_dict(state)
return model
_DEFAULT_WEIGHTS_PATH = (
    './work_dir/recognition/kinetics_skeleton/ST_GCN/best_model_8p.pt')


def build_model(weights_path=_DEFAULT_WEIGHTS_PATH, device='npu:0'):
    """Build the ST-GCN model on an NPU, wrap it with apex AMP, and load weights.

    Args:
        weights_path: checkpoint passed to ``load_weights``; defaults to the
            original hard-coded path.
        device: NPU device selected via ``torch.npu.set_device`` before the
            model is moved there; defaults to ``'npu:0'``.

    Returns:
        The AMP-initialized model in eval mode.
    """
    from net.st_gcn import Model
    torch.npu.set_device(device)
    model = Model(in_channels=3,
                  num_class=400,
                  edge_importance_weighting=True,
                  graph_args={'layout': "openpose",
                              'strategy': "spatial"})
    model = model.npu()
    # The optimizer is never stepped in this script; presumably it is only
    # here to mirror the training-time amp.initialize call.
    # NOTE(review): confirm whether it can be dropped for inference.
    optimizer = optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9)
    model, optimizer = amp.initialize(
        model, optimizer, opt_level="O2", loss_scale=1024)
    model = load_weights(model, weights_path)
    model.eval()
    return model
def get_raw_data(path='tools/raw_data.npy'):
    """Load the sample input array from ``path`` and move it to the NPU.

    Args:
        path: ``.npy`` file holding the model input. Relative paths resolve
            against the current working directory. Defaults to the original
            hard-coded location.

    Returns:
        A tensor with the array's dtype and shape on the current NPU device.
        The layout is presumably ST-GCN's (N, C, T, V, M) -- TODO confirm.
    """
    array = np.load(path)
    return torch.from_numpy(array).npu()
def pre_process(raw_data):
    """Placeholder preprocessing step.

    No transformation is applied; the very same object is handed back.
    """
    processed = raw_data
    return processed
def post_process(output_tensor):
    """Return the argmax index along dimension 1 of ``output_tensor``.

    Presumably dimension 1 holds per-class scores, so this yields the
    predicted class per sample -- confirm against the model output.
    """
    class_axis = 1
    predictions = output_tensor.argmax(dim=class_axis)
    return predictions
if __name__ == '__main__':
    # Build the model first: build_model() calls torch.npu.set_device, so the
    # input moved to the NPU afterwards lands on the same device as the model.
    model = build_model()
    raw_data = get_raw_data()
    input_tensor = pre_process(raw_data)
    # Inference only: disable autograd so no graph is built or kept in memory.
    with torch.no_grad():
        output_tensor = model(input_tensor)
    result = post_process(output_tensor)
    print(result)