# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
class AsyncEvaluator():
"""
Creates a threaded evaluator for a given device.
If device == None then the current active device is used
"""
def __init__(self, num_threads=1, device=None):
self.num_threads = num_threads
# self.pool = ThreadPoolExecutor(num_threads)
self.pool = ProcessPoolExecutor(num_threads)
self.events = {}
def __del__(self):
for t, e in self.events.items():
e.cancel()
# submit given function and its arguments with an
# associated tag for bookkkeeping
def submit_task(self, tag, fn, *args, **kwargs):
# launch work
e = self.pool.submit(fn, *args, **kwargs)
# record work
self.events[tag] = e
# check if a task has completed
def task_done(self, tag):
return self.events[tag].done()
# get the result of a task:
# Note: will block until completed
def task_result(self, tag):
return self.events[tag].result(timeout=None)
# Get all currently finished tasks in a dict of
# { tag : result }
def finished_tasks(self):
ret = {}
to_remove = []
# Check all existing tasks
for t in self.events.keys():
done = self.events[t].done()
if done:
ret[t] = self.task_result(t)
to_remove.append(t)
# As soon as a task is finished we want to remove it
for t in to_remove:
self.task_remove(t)
return ret
# remove a task from the outstanding list
# Note: will cancel task if not completed
def task_remove(self, tag):
done = self.events[tag].done()
# cancel task if necessary
if not done:
self.events[tag].cancel()
# remove the entry
del self.events[tag]
# return list of tags outstanding
def task_tags(self):
return self.events.keys()
# wait for everything to finish
def wait_all_tasks(self):
for t in self.events.keys():
y = self.task_result(t)
print('task {} finished'.format(t))