import torch
import torch
from atk.configs.dataset_config import InputDataset
from atk.configs.results_config import TaskResult
from atk.tasks.api_execute import register
from atk.tasks.api_execute.base_api import BaseApi
from atk.tasks.dataset.base_dataset import OpsDataset
from atk.tasks.api_execute.aclnn_base_api import AclnnBaseApi
@register("ascend_function_aclnn_stft")
class AclnnStftApi(AclnnBaseApi):
    """STFT task registered under ``ascend_function_aclnn_stft``."""

    def init_by_input_data(self, input_data):
        """Rearrange the arguments produced by the base class.

        The base implementation's last positional argument is moved to
        index 2, and ``output_packages`` is overwritten in place so that it
        holds only that relocated item.

        NOTE(review): presumably the trailing argument is the output buffer
        and the aclnn call expects it third -- confirm against the operator
        signature.

        Args:
            input_data: Forwarded unchanged to the base class.

        Returns:
            The ``(input_args, output_packages)`` pair after reordering.
        """
        input_args, output_packages = super().init_by_input_data(input_data)

        # Take the trailing argument off the end and re-insert it at slot 2.
        trailing = input_args.pop()
        input_args.insert(2, trailing)

        # Slice assignment mutates the existing list object instead of
        # rebinding the name, so other holders of the list see the change.
        output_packages[:] = [input_args[2]]
        return input_args, output_packages
@register("ascend_function_torch_stft")
class TorchStftApi(BaseApi):
    """STFT task backed by ``torch.stft``, registered under
    ``ascend_function_torch_stft``."""

    def __init__(self, task_result: TaskResult):
        """Initialise the base API, seed the dataset RNGs, and reset
        ``change_flag``.

        Args:
            task_result: Passed straight through to the base constructor.
        """
        super(TorchStftApi, self).__init__(task_result)
        OpsDataset.seed_everything()
        self.change_flag = None

    def __call__(self, input_data: InputDataset, with_output: bool = False):
        """Produce the result for one input case.

        When ``self.output`` is ``None`` the result is computed with
        ``torch.stft`` (always with ``center=False``). Otherwise
        ``self.output`` selects an existing value from ``input_data``: an
        ``int`` indexes ``input_data.args`` and a ``str`` keys
        ``input_data.kwargs``.

        Args:
            input_data: Case data; its ``kwargs`` must contain ``self``,
                ``windowOptional``, ``nFft``, ``hopLength``, ``winLength``,
                ``normalized``, ``onesided`` and ``returnComplex``.
            with_output: Not used by this implementation.

        Returns:
            The ``torch.stft`` result, or the selected entry of
            ``input_data``.

        Raises:
            KeyError: If one of the required kwargs is missing.
            ValueError: If ``self.output`` is neither ``None``, ``int`` nor
                ``str``.
        """
        params = input_data.kwargs

        # Every parameter is looked up before branching, so a missing key
        # fails fast regardless of which branch is taken below.
        signal = params["self"]
        window_tensor = params["windowOptional"]
        n_fft = params["nFft"]
        hop = params["hopLength"]
        win_len = params["winLength"]
        is_normalized = params["normalized"]
        is_onesided = params["onesided"]
        as_complex = params["returnComplex"]

        selector = self.output

        if selector is None:
            return torch.stft(
                signal,
                n_fft,
                hop_length=hop,
                win_length=win_len,
                window=window_tensor,
                center=False,
                normalized=is_normalized,
                onesided=is_onesided,
                return_complex=as_complex,
            )

        if isinstance(selector, int):
            return input_data.args[selector]

        if isinstance(selector, str):
            return params[selector]

        raise ValueError(f"self.output {selector} value is error")