* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "SumFunction.h"
SumFunction::SumFunction(int aggIdx, std::string inputType, int accIndex, int valueIndex, int filterIndex)
: aggIdx(aggIdx),
inputType(inputType),
accIndex(accIndex),
valueIndex(valueIndex),
filterIndex(filterIndex)
{
hasFilter = filterIndex != -1;
store = nullptr;
accIndexCount0 = -1;
}
void SumFunction::accumulate(RowData* accInput)
{
bool shouldDoAccumulate = true;
if (hasFilter) {
bool isFilterNull = accInput->isNullAt(filterIndex);
shouldDoAccumulate = !isFilterNull && *(accInput->getBool(filterIndex));
}
if (shouldDoAccumulate) {
auto typeId = LogicalType::flinkTypeToOmniTypeId(inputType);
bool isFieldNull = accInput->isNullAt(aggIdx);
long fieldValue;
switch (typeId) {
case DataTypeId::OMNI_INT: {
fieldValue = isFieldNull ? -1L : *accInput->getInt(aggIdx);
break;
}
case DataTypeId::OMNI_LONG: {
fieldValue = isFieldNull ? -1L : *accInput->getLong(aggIdx);
break;
}
default: LOG("Data type is not supported."); throw std::runtime_error("Data type is not supported.");
}
if (!isFieldNull) {
if (sumIsNull) {
sum = fieldValue;
sumIsNull = false;
} else {
sum += fieldValue;
}
if (consumeRetraction) {
count0IsNull = false;
count0 += 1;
}
}
}
}
void SumFunction::accumulate(omnistream::VectorBatch* input, const std::vector<int>& indices)
{
auto columnData = input->Get(aggIdx);
const bool hasFilterCol = hasFilter;
const auto filterData =
hasFilterCol ? reinterpret_cast<omniruntime::vec::Vector<bool>*>(input->Get(filterIndex)) : nullptr;
for (int rowIndex : indices) {
bool isFilter = true;
if (hasFilterCol) {
bool isFilterNull = filterData->IsNull(rowIndex);
isFilter = !isFilterNull && filterData->GetValue(rowIndex);
}
if (!isFilter) continue;
bool isFieldNull = columnData->IsNull(rowIndex);
long fieldValue;
auto typeId = LogicalType::flinkTypeToOmniTypeId(inputType);
switch (typeId) {
case DataTypeId::OMNI_INT: {
fieldValue =
isFieldNull ? -1L : dynamic_cast<omniruntime::vec::Vector<int>*>(columnData)->GetValue(rowIndex);
break;
}
case DataTypeId::OMNI_LONG: {
fieldValue =
isFieldNull ? -1L : dynamic_cast<omniruntime::vec::Vector<long>*>(columnData)->GetValue(rowIndex);
break;
}
default: LOG("Data type is not supported."); throw std::runtime_error("Data type is not supported.");
}
if (!isFieldNull) {
if (sumIsNull) {
sum = fieldValue;
sumIsNull = false;
} else {
sum += fieldValue;
}
}
}
}
void SumFunction::setAccumulators(RowData* _acc)
{
sumIsNull = _acc->isNullAt(accIndex);
sum = sumIsNull ? 0L : *_acc->getLong(accIndex);
if (consumeRetraction) {
count0 = count0IsNull ? 0L : *_acc->getLong(accIndexCount0);
}
}
void SumFunction::resetAccumulators()
{
sum = 0;
sumIsNull = true;
if (consumeRetraction) {
count0 = 0;
count0IsNull = true;
}
}
void SumFunction::createAccumulators(BinaryRowData* accumulators)
{
accumulators->setNullAt(accIndex);
if (consumeRetraction) {
accumulators->setNullAt(accIndexCount0);
}
}
void SumFunction::open(StateDataViewStore* store)
{
this->store = store;
}
void SumFunction::retract(RowData* retractInput)
{
auto typeId = LogicalType::flinkTypeToOmniTypeId(inputType);
bool isFieldNull = retractInput->isNullAt(aggIdx);
long fieldValue;
switch (typeId) {
case DataTypeId::OMNI_INT: {
fieldValue = isFieldNull ? -1L : *retractInput->getInt(aggIdx);
break;
}
case DataTypeId::OMNI_LONG: {
fieldValue = isFieldNull ? -1L : *retractInput->getLong(aggIdx);
break;
}
default: LOG("Data type is not supported."); throw std::runtime_error("Data type is not supported.");
}
if (!isFieldNull) {
sum = sumIsNull ? sum : sum - fieldValue;
if (consumeRetraction) {
count0 -= 1;
}
}
}
void SumFunction::retract(omnistream::VectorBatch* input, const std::vector<int>& indices)
{
auto columnData = input->Get(aggIdx);
auto typeId = LogicalType::flinkTypeToOmniTypeId(inputType);
for (int rowIndex : indices) {
bool isFieldNull = columnData->IsNull(rowIndex);
long fieldValue;
switch (typeId) {
case DataTypeId::OMNI_INT: {
fieldValue =
isFieldNull ? -1L : dynamic_cast<omniruntime::vec::Vector<int>*>(columnData)->GetValue(rowIndex);
break;
}
case DataTypeId::OMNI_LONG: {
fieldValue =
isFieldNull ? -1L : dynamic_cast<omniruntime::vec::Vector<long>*>(columnData)->GetValue(rowIndex);
break;
}
default: LOG("Data type is not supported."); throw std::runtime_error("Data type is not supported.");
}
if (!isFieldNull) {
sum = sumIsNull ? sum : sum - fieldValue;
}
}
}
void SumFunction::merge(RowData* otherAcc)
{
throw std::runtime_error("This function does not require the merge method, but the merge method is called.");
}
void SumFunction::getAccumulators(BinaryRowData* acc)
{
if (sumIsNull) {
acc->setNullAt(accIndex);
} else {
acc->setLong(accIndex, sum);
}
if (consumeRetraction) {
if (count0IsNull) {
acc->setNullAt(accIndexCount0);
} else {
acc->setLong(accIndexCount0, count0);
}
}
}
void SumFunction::getValue(BinaryRowData* aggValue)
{
if (sumIsNull) {
aggValue->setNullAt(valueIndex);
} else {
aggValue->setLong(valueIndex, sum);
}
}
bool SumFunction::equaliser(BinaryRowData* r1, BinaryRowData* r2)
{
if (r1->isNullAt(valueIndex) || r2->isNullAt(valueIndex)) {
return false;
}
bool isEqual = false;
auto typeId = LogicalType::flinkTypeToOmniTypeId(inputType);
switch (typeId) {
case DataTypeId::OMNI_INT: {
isEqual = *r1->getInt(valueIndex) == *r2->getInt(valueIndex);
break;
}
case DataTypeId::OMNI_LONG: {
isEqual = *r1->getLong(valueIndex) == *r2->getLong(valueIndex);
break;
}
default: LOG("Data type is not supported."); throw std::runtime_error("Data type is not supported.");
}
return isEqual;
}
void SumFunction::setRetraction(int accIndexCount0Index)
{
this->accIndexCount0 = accIndexCount0Index;
if (accIndexCount0Index >= 0) {
consumeRetraction = true;
}
}