import numpy as np
import torch
import tensorflow as tf
from atk.configs.dataset_config import InputDataset
from atk.configs.results_config import TaskResult
from atk.tasks.api_execute import register
from atk.tasks.api_execute.base_api import BaseApi
@register("function_aclnnAddN")
class aclnnAddNExecutor(BaseApi):
    """Reference executor for ``aclnnAddN``: element-wise sum of a list of tensors.

    The sum is delegated to TensorFlow's ``tf.math.add_n`` and the result is
    returned as a ``torch.Tensor``.
    """

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a host NumPy array, detached from autograd."""
        # .numpy() raises for tensors that require grad or live off-CPU.
        return tensor.detach().cpu().numpy()

    def __call__(self, input_data: InputDataset, with_output: bool = False) -> torch.Tensor:
        """Sum the tensors stored under ``input_data.kwargs["input"]``.

        Args:
            input_data: dataset whose ``kwargs["input"]`` holds a non-empty
                sequence of torch tensors (presumably all of one shape and
                dtype, as ``tf.math.add_n`` requires -- confirm with callers).
            with_output: accepted for interface compatibility; unused here.

        Returns:
            A ``torch.Tensor`` holding the element-wise sum. bfloat16 inputs
            produce a bfloat16 result; otherwise the dtype is whatever
            ``tf.math.add_n`` returns for the converted inputs.

        Raises:
            ValueError: if ``input`` is missing or empty.
        """
        x = input_data.kwargs.get("input")
        if x is None or len(x) == 0:
            raise ValueError("aclnnAddN expects a non-empty sequence of tensors in 'input'")

        if x[0].dtype == torch.bfloat16:
            # NumPy has no native bfloat16: widen to float32 in torch, then
            # narrow to TF's bfloat16 NumPy dtype so TF sums in bfloat16.
            bf16 = tf.bfloat16.as_numpy_dtype
            x_casted = [self._to_numpy(item.to(torch.float32)).astype(bf16) for item in x]
            summed = tf.cast(tf.math.add_n(x_casted), tf.float32).numpy()
            # np.asarray: for 0-d results .numpy() yields a NumPy scalar,
            # which torch.from_numpy rejects.
            return torch.from_numpy(np.asarray(summed)).to(torch.bfloat16)

        summed = tf.math.add_n([self._to_numpy(item) for item in x]).numpy()
        return torch.from_numpy(np.asarray(summed))