* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "AverageFunction.h"
/**
 * Creates an AVG aggregate.
 *
 * @param aggIdx         column/field index of the aggregated input value
 * @param inputType      Flink type name of the input; converted to an omni type id
 * @param accIndexSum    accumulator slot holding the running sum
 * @param accIndexCount0 accumulator slot holding the running count of non-NULL inputs
 * @param valueIndex     output slot the average is written to
 * @param filterIndex    column/field index of the boolean FILTER value, or -1 when there is no filter
 */
AverageFunction::AverageFunction(int aggIdx, std::string inputType, int accIndexSum, int accIndexCount0, int valueIndex,
    int filterIndex)
    : aggIdx(aggIdx), accIndexSum(accIndexSum), accIndexCount0(accIndexCount0), valueIndex(valueIndex),
      filterIndex(filterIndex)
{
    // No state view store until open() is called.
    store = nullptr;
    // -1 is the sentinel for "no FILTER clause".
    hasFilter = (filterIndex != -1);
    typeId = LogicalType::flinkTypeToOmniTypeId(inputType);
}
/**
 * Adds one input row to the running sum/count.
 *
 * Rows rejected by the filter (NULL or false filter value) and rows whose aggregated field is NULL
 * do not change the accumulators.
 *
 * @throws std::runtime_error if the input type is neither OMNI_INT nor OMNI_LONG
 */
void AverageFunction::accumulate(RowData *accInput)
{
    // Guard: skip rows that do not pass the filter.
    if (hasFilter && (accInput->isNullAt(filterIndex) || !*(accInput->getBool(filterIndex)))) {
        return;
    }

    const bool isFieldNull = accInput->isNullAt(aggIdx);
    long fieldValue = -1L;  // placeholder; only read when the field is non-NULL
    if (typeId == DataTypeId::OMNI_INT) {
        if (!isFieldNull) {
            fieldValue = *accInput->getInt(aggIdx);
        }
    } else if (typeId == DataTypeId::OMNI_LONG) {
        if (!isFieldNull) {
            fieldValue = *accInput->getLong(aggIdx);
        }
    } else {
        LOG("Data type is not supported.");
        throw std::runtime_error("Data type is not supported.");
    }

    if (isFieldNull) {
        return;
    }
    // A NULL accumulator starts fresh from the first value instead of adding to it.
    sum = sumIsNull ? fieldValue : sum + fieldValue;
    sumIsNull = false;
    count0 = count0IsNull ? 1 : count0 + 1;
    count0IsNull = false;
}
/**
 * Adds the selected rows of a vector batch to the running sum/count.
 *
 * @param input   batch holding the aggregated column (and the filter column when hasFilter)
 * @param indices row positions within the batch to accumulate
 * @throws std::runtime_error if the input type is neither OMNI_INT nor OMNI_LONG, or if the
 *         column vector does not have the element type implied by that input type
 */
void AverageFunction::accumulate(omnistream::VectorBatch *input, const std::vector<int>& indices)
{
    auto columnData = input->Get(aggIdx);
    const auto filterData =
        hasFilter ? reinterpret_cast<omniruntime::vec::Vector<bool> *>(input->Get(filterIndex)) : nullptr;

    // Resolve the typed view of the column once per batch instead of once per row: dynamic_cast
    // walks RTTI and was previously repeated inside the row loop.
    omniruntime::vec::Vector<int> *intColumn = nullptr;
    omniruntime::vec::Vector<long> *longColumn = nullptr;
    if (typeId == DataTypeId::OMNI_INT) {
        intColumn = dynamic_cast<omniruntime::vec::Vector<int> *>(columnData);
    } else if (typeId == DataTypeId::OMNI_LONG) {
        longColumn = dynamic_cast<omniruntime::vec::Vector<long> *>(columnData);
    }

    for (int rowIndex : indices) {
        // Rows whose filter value is NULL or false do not contribute.
        if (hasFilter && (filterData->IsNull(rowIndex) || !filterData->GetValue(rowIndex))) {
            continue;
        }
        const bool isFieldNull = columnData->IsNull(rowIndex);
        long fieldValue = -1L;  // placeholder; only read when the field is non-NULL
        switch (typeId) {
            case DataTypeId::OMNI_INT: {
                if (!isFieldNull) {
                    // A failed cast used to be dereferenced directly (undefined behaviour).
                    if (intColumn == nullptr) {
                        LOG("Input column does not match the aggregate input type.");
                        throw std::runtime_error("Input column does not match the aggregate input type.");
                    }
                    fieldValue = intColumn->GetValue(rowIndex);
                }
                break;
            }
            case DataTypeId::OMNI_LONG: {
                if (!isFieldNull) {
                    if (longColumn == nullptr) {
                        LOG("Input column does not match the aggregate input type.");
                        throw std::runtime_error("Input column does not match the aggregate input type.");
                    }
                    fieldValue = longColumn->GetValue(rowIndex);
                }
                break;
            }
            default:
                LOG("Data type is not supported.");
                throw std::runtime_error("Data type is not supported.");
        }
        if (isFieldNull) {
            continue;
        }
        // A NULL accumulator starts fresh from the first value instead of adding to it.
        sum = sumIsNull ? fieldValue : sum + fieldValue;
        sumIsNull = false;
        count0 = count0IsNull ? 1 : count0 + 1;
        count0IsNull = false;
    }
}
/**
 * Loads the sum/count accumulators (and their NULL flags) from an accumulator row.
 * NULL slots are loaded as 0 with the corresponding NULL flag set.
 */
void AverageFunction::setAccumulators(RowData *_acc)
{
    sumIsNull = _acc->isNullAt(accIndexSum);
    sum = 0L;
    if (!sumIsNull) {
        sum = *_acc->getLong(accIndexSum);
    }

    count0IsNull = _acc->isNullAt(accIndexCount0);
    count0 = 0L;
    if (!count0IsNull) {
        count0 = *_acc->getLong(accIndexCount0);
    }
}
/**
 * Resets the in-memory accumulators to a non-NULL zero sum and zero count.
 */
void AverageFunction::resetAccumulators()
{
    sumIsNull = false;
    sum = 0;
    count0IsNull = false;
    count0 = 0;
}
void AverageFunction::open(StateDataViewStore *store)
{
this->store = store;
}
/**
 * Initialises the sum and count slots of a fresh accumulator row to 0.
 */
void AverageFunction::createAccumulators(BinaryRowData* accumulators)
{
    const long initialValue = 0L;
    accumulators->setLong(accIndexSum, initialValue);
    accumulators->setLong(accIndexCount0, initialValue);
}
/**
 * Removes one previously accumulated row from the running sum/count.
 *
 * Applies the same FILTER as accumulate(). Rows the filter rejects were never added, so they must
 * not be subtracted either. NULL fields are ignored, and a NULL accumulator stays unchanged.
 *
 * @throws std::runtime_error if the input type is neither OMNI_INT nor OMNI_LONG
 */
void AverageFunction::retract(RowData *retractInput)
{
    // Mirror accumulate(): skip rows whose filter value is NULL or false.
    if (hasFilter && (retractInput->isNullAt(filterIndex) || !*(retractInput->getBool(filterIndex)))) {
        return;
    }
    bool isFieldNull = retractInput->isNullAt(aggIdx);
    long fieldValue;
    switch (typeId) {
        case DataTypeId::OMNI_INT: {
            fieldValue = isFieldNull ? -1L : *retractInput->getInt(aggIdx);
            break;
        }
        case DataTypeId::OMNI_LONG: {
            fieldValue = isFieldNull ? -1L : *retractInput->getLong(aggIdx);
            break;
        }
        default:
            LOG("Data type is not supported.");
            throw std::runtime_error("Data type is not supported.");
    }
    if (!isFieldNull) {
        sum = sumIsNull ? sum : sum - fieldValue;
        count0 = count0IsNull ? count0 : count0 - 1;
    }
}
/**
 * Removes the selected rows of a vector batch from the running sum/count.
 *
 * Applies the same FILTER as the batch accumulate(), so rows that were never accumulated are
 * not subtracted. NULL fields are ignored, and a NULL accumulator stays unchanged.
 *
 * @param input   batch holding the aggregated column (and the filter column when hasFilter)
 * @param indices row positions within the batch to retract
 * @throws std::runtime_error if the input type is neither OMNI_INT nor OMNI_LONG, or if the
 *         column vector does not have the element type implied by that input type
 */
void AverageFunction::retract(omnistream::VectorBatch* input, const std::vector<int>& indices)
{
    auto columnData = input->Get(aggIdx);
    const auto filterData =
        hasFilter ? reinterpret_cast<omniruntime::vec::Vector<bool> *>(input->Get(filterIndex)) : nullptr;

    // Resolve the typed column view once per batch rather than with a dynamic_cast per row.
    omniruntime::vec::Vector<int> *intColumn = nullptr;
    omniruntime::vec::Vector<long> *longColumn = nullptr;
    if (typeId == DataTypeId::OMNI_INT) {
        intColumn = dynamic_cast<omniruntime::vec::Vector<int> *>(columnData);
    } else if (typeId == DataTypeId::OMNI_LONG) {
        longColumn = dynamic_cast<omniruntime::vec::Vector<long> *>(columnData);
    }

    for (int rowIndex : indices) {
        // Mirror accumulate(): rows whose filter value is NULL or false were never added.
        if (hasFilter && (filterData->IsNull(rowIndex) || !filterData->GetValue(rowIndex))) {
            continue;
        }
        const bool isFieldNull = columnData->IsNull(rowIndex);
        long fieldValue = -1L;  // placeholder; only read when the field is non-NULL
        switch (typeId) {
            case DataTypeId::OMNI_INT: {
                if (!isFieldNull) {
                    // A failed cast used to be dereferenced directly (undefined behaviour).
                    if (intColumn == nullptr) {
                        LOG("Input column does not match the aggregate input type.");
                        throw std::runtime_error("Input column does not match the aggregate input type.");
                    }
                    fieldValue = intColumn->GetValue(rowIndex);
                }
                break;
            }
            case DataTypeId::OMNI_LONG: {
                if (!isFieldNull) {
                    if (longColumn == nullptr) {
                        LOG("Input column does not match the aggregate input type.");
                        throw std::runtime_error("Input column does not match the aggregate input type.");
                    }
                    fieldValue = longColumn->GetValue(rowIndex);
                }
                break;
            }
            default:
                LOG("Data type is not supported.");
                throw std::runtime_error("Data type is not supported.");
        }
        if (!isFieldNull) {
            sum = sumIsNull ? sum : sum - fieldValue;
            count0 = count0IsNull ? count0 : count0 - 1;
        }
    }
}
/**
 * Merging accumulators is not supported by this function; any call is treated as an error.
 *
 * @throws std::runtime_error always
 */
void AverageFunction::merge(RowData *otherAcc)
{
    (void)otherAcc;  // intentionally unused
    const char *const message =
        "This function does not require the merge method, but the merge method is called.";
    throw std::runtime_error(message);
}
/**
 * Writes the in-memory sum and count into an accumulator row.
 * A slot whose NULL flag is set is written as NULL.
 */
void AverageFunction::getAccumulators(BinaryRowData *acc)
{
    if (!sumIsNull) {
        acc->setLong(accIndexSum, sum);
    } else {
        acc->setNullAt(accIndexSum);
    }

    if (!count0IsNull) {
        acc->setLong(accIndexCount0, count0);
    } else {
        acc->setNullAt(accIndexCount0);
    }
}
/**
 * Writes the current average (sum / count0, integer division) to valueIndex.
 * Writes NULL when the count is NULL or zero, or when the sum is NULL.
 *
 * NOTE(review): the result is always written with setLong, even for OMNI_INT input, while
 * equaliser() reads OMNI_INT results with getInt. Confirm the expected output column type.
 */
void AverageFunction::getValue(BinaryRowData *aggValue)
{
    const bool hasAverage = !count0IsNull && count0 != 0 && !sumIsNull;
    if (!hasAverage) {
        aggValue->setNullAt(valueIndex);
        return;
    }
    const long average = sum / count0;
    aggValue->setLong(valueIndex, average);
}
/**
 * Compares the aggregate result stored at valueIndex in two rows.
 *
 * Follows record-equality semantics:
 * - two NULL results are equal;
 * - NULL vs non-NULL is not equal;
 * - otherwise the values are compared according to typeId.
 * Previously two NULLs compared unequal, which made callers treat an unchanged NULL average as
 * changed.
 *
 * NOTE(review): OMNI_INT values are read with getInt, but getValue() writes the result with
 * setLong. Confirm the stored representation.
 *
 * @throws std::runtime_error if both values are non-NULL and the type is neither OMNI_INT nor OMNI_LONG
 */
bool AverageFunction::equaliser(BinaryRowData *r1, BinaryRowData *r2)
{
    const bool leftIsNull = r1->isNullAt(valueIndex);
    const bool rightIsNull = r2->isNullAt(valueIndex);
    if (leftIsNull || rightIsNull) {
        return leftIsNull && rightIsNull;
    }
    bool isEqual = false;
    switch (typeId) {
        case DataTypeId::OMNI_INT: {
            isEqual = *r1->getInt(valueIndex) == *r2->getInt(valueIndex);
            break;
        }
        case DataTypeId::OMNI_LONG: {
            isEqual = *r1->getLong(valueIndex) == *r2->getLong(valueIndex);
            break;
        }
        default:
            LOG("Data type is not supported.");
            throw std::runtime_error("Data type is not supported.");
    }
    return isEqual;
}