/*
 * Copyright (c) 2024 Huawei Technologies Co.,Ltd.
 *
 * openGauss is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
 *
 * sparsevec.cpp
 *
 * IDENTIFICATION
 *        src/common/backend/utils/adt/sparsevec.cpp
 *
 * -------------------------------------------------------------------------
 */
#include "postgres.h"

#include <climits>
#include <cmath>

#include "fmgr.h"
#include "access/datavec/halfutils.h"
#include "access/datavec/halfvec.h"
#include "libpq/pqformat.h"
#include "utils/shortest_dec.h"
#include "access/datavec/sparsevec.h"
#include "utils/array.h"
#include "utils/builtins.h"
#include "access/datavec/vector.h"

#include <cfloat>
#include "utils/builtins.h"

/*
 * One parsed "index:value" pair collected by sparsevec_in before the
 * elements are sorted by index and copied into the result vector.
 */
typedef struct SparseInputElement {
    int32 index; /* 0-based index (converted from the 1-based text form) */
    float value; /* element value; sparsevec_in only stores non-zero values */
} SparseInputElement;

/*
 * Raise an error unless both sparse vectors have the same number of dimensions.
 */
static inline void CheckDims(SparseVector *a, SparseVector *b)
{
    int adim = a->dim;
    int bdim = b->dim;

    if (adim == bdim) {
        return;
    }

    ereport(ERROR,
            (errcode(ERRCODE_DATA_EXCEPTION), errmsg("different sparsevec dimensions %d and %d", adim, bdim)));
}

/*
 * Raise an error when a type modifier is present (anything but -1) and it
 * disagrees with the actual number of dimensions.
 */
static inline void CheckExpectedDim(int32 typmod, int dim)
{
    bool unconstrained = (typmod == -1);

    if (unconstrained || typmod == dim) {
        return;
    }

    ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("expected %d dimensions, not %d", typmod, dim)));
}

/*
 * Raise an error unless dim lies within [1, SPARSEVEC_MAX_DIM].
 */
static inline void CheckDim(int dim)
{
    if (dim < 1) {
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("sparsevec must have at least 1 dimension")));
    } else if (dim > SPARSEVEC_MAX_DIM) {
        ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                        errmsg("sparsevec cannot have more than %d dimensions", SPARSEVEC_MAX_DIM)));
    }
}

/*
 * Ensure valid nnz
 */
static inline void CheckNnz(int nnz, int dim)
{
    if (nnz < 0)
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("sparsevec cannot have negative number of elements")));

    if (nnz > SPARSEVEC_MAX_NNZ)
        ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                        errmsg("sparsevec cannot have more than %d non-zero elements", SPARSEVEC_MAX_NNZ)));

    if (nnz > dim)
        ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                        errmsg("sparsevec cannot have more elements than dimensions")));
}

/*
 * Validate indices[i]: it must lie in [0, dim), and it must be strictly
 * greater than the previous index so that indices stay ascending and unique.
 */
static inline void CheckIndex(int32 *indices, int i, int dim)
{
    int32 cur = indices[i];
    int32 prev;

    if (cur < 0 || cur >= dim) {
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("sparsevec index out of bounds")));
    }

    /* The first element has no predecessor to compare against */
    if (i == 0) {
        return;
    }

    prev = indices[i - 1];

    if (cur < prev) {
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("sparsevec indices must be in ascending order")));
    }

    if (cur == prev) {
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("sparsevec indices must not contain duplicates")));
    }
}

/*
 * Reject NaN and infinite element values; only finite floats may be stored.
 */
static inline void CheckElement(float value)
{
    if (isnan(value)) {
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("NaN not allowed in sparsevec")));
    } else if (isinf(value)) {
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("infinite value not allowed in sparsevec")));
    }
}

/*
 * Allocate a zero-filled sparse vector with room for nnz elements, with its
 * varlena size header, dimension count and element count already set.
 */
SparseVector *InitSparseVector(int dim, int nnz)
{
    int size = SPARSEVEC_SIZE(nnz);
    SparseVector *vec = (SparseVector *)palloc0(size);

    SET_VARSIZE(vec, size);
    vec->nnz = nnz;
    vec->dim = dim;

    return vec;
}

/*
 * Whitespace test used by the parser (array_isspace() is static, so it cannot
 * be reused): space, tab, newline, carriage return, vertical tab, form feed.
 */
static inline bool SparsevecIsspace(char ch)
{
    switch (ch) {
        case ' ':
        case '\t':
        case '\n':
        case '\r':
        case '\v':
        case '\f':
            return true;
        default:
            return false;
    }
}

/*
 * Compare indices
 */
static int CompareIndices(const void *a, const void *b)
{
    if (((SparseInputElement *)a)->index < ((SparseInputElement *)b)->index) {
        return -1;
    }

    if (((SparseInputElement *)a)->index > ((SparseInputElement *)b)->index) {
        return 1;
    }

    return 0;
}

/*
 * Convert textual representation to internal representation
 *
 * Parses "{i1:v1,i2:v2,...}/dim", where indices are 1-based. Zero values are
 * dropped, and the remaining elements are sorted by index. The result is then
 * validated: dimension count, element count, type modifier, and index
 * range/order/uniqueness.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_in);
Datum sparsevec_in(PG_FUNCTION_ARGS)
{
    char *lit = PG_GETARG_CSTRING(0);
    int32 typmod = PG_GETARG_INT32(2);
    long dim;
    char *pt = lit;
    char *stringEnd;
    SparseVector *result;
    float *rvalues;
    SparseInputElement *elements;
    int maxNnz;
    int capacity;
    int nnz = 0;

    /*
     * Every element after the first is preceded by a comma, so the comma
     * count + 1 bounds how many elements the parse loop below can see.
     */
    maxNnz = 1;
    while (*pt != '\0') {
        if (*pt == ',') {
            maxNnz++;
        }

        pt++;
    }

    /*
     * More than SPARSEVEC_MAX_NNZ stored (non-zero) elements is always an
     * error, so never allocate more than that. The store below checks this
     * capacity before every write, so an oversized input cannot overrun it.
     */
    capacity = Min(maxNnz, SPARSEVEC_MAX_NNZ);
    elements = (SparseInputElement *)palloc(capacity * sizeof(SparseInputElement));

    pt = lit;

    while (SparsevecIsspace(*pt)) {
        pt++;
    }

    if (*pt != '{')
        ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                        errmsg("invalid input syntax for type sparsevec: \"%s\"", lit),
                        errdetail("Sparsevec contents must start with \"{\".")));

    pt++;

    while (SparsevecIsspace(*pt)) {
        pt++;
    }

    if (*pt == '}') {
        pt++;
    } else {
        for (;;) {
            long index;
            float value;

            while (SparsevecIsspace(*pt)) {
                pt++;
            }

            /* Check for empty string like float4in */
            if (*pt == '\0')
                ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                                errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));

            /* Use similar logic as int2vectorin */
            index = strtol(pt, &stringEnd, 10);

            if (stringEnd == pt)
                ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                                errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));

            /* Keep in int range for correct error message later */
            if (index > INT_MAX) {
                index = INT_MAX;
            } else if (index < INT_MIN + 1) {
                index = INT_MIN + 1;
            }

            pt = stringEnd;

            while (SparsevecIsspace(*pt)) {
                pt++;
            }

            if (*pt != ':')
                ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                                errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));

            pt++;

            while (SparsevecIsspace(*pt)) {
                pt++;
            }

            errno = 0;

            /* Use strtof like float4in to avoid a double-rounding problem */
            /* Postgres sets LC_NUMERIC to C on startup */
            value = strtof(pt, &stringEnd);

            if (stringEnd == pt)
                ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                                errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));

            /* Check for range error like float4in */
            if (errno == ERANGE && (value == 0 || isinf(value)))
                ereport(ERROR, (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
                                errmsg("\"%s\" is out of range for type sparsevec", pnstrdup(pt, stringEnd - pt))));

            CheckElement(value);

            /* Do not store zero values */
            if (value != 0) {
                /*
                 * The buffer holds at most SPARSEVEC_MAX_NNZ elements, and
                 * maxNnz already bounds the element count, so this fires only
                 * when the input has too many non-zero elements. Storing
                 * another one would write past the allocation.
                 */
                if (nnz >= capacity) {
                    ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                                    errmsg("sparsevec cannot have more than %d non-zero elements",
                                           SPARSEVEC_MAX_NNZ)));
                }

                /* Convert 1-based numbering (SQL) to 0-based (C); the clamp above keeps this in int32 range */
                elements[nnz].index = (int32)(index - 1);
                elements[nnz].value = value;
                nnz++;
            }

            pt = stringEnd;

            while (SparsevecIsspace(*pt)) {
                pt++;
            }

            if (*pt == ',') {
                pt++;
            } else if (*pt == '}') {
                pt++;
                break;
            } else {
                ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                                errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));
            }
        }
    }

    while (SparsevecIsspace(*pt)) {
        pt++;
    }

    if (*pt != '/') {
        ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                        errmsg("invalid input syntax for type sparsevec: \"%s\"", lit),
                        errdetail("Unexpected end of input.")));
    }

    pt++;

    while (SparsevecIsspace(*pt)) {
        pt++;
    }

    /* Use similar logic as int2vectorin */
    dim = strtol(pt, &stringEnd, 10);

    if (stringEnd == pt)
        ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                        errmsg("invalid input syntax for type sparsevec: \"%s\"", lit)));

    /* Keep in int range for correct error message later */
    if (dim > INT_MAX) {
        dim = INT_MAX;
    } else if (dim < INT_MIN) {
        dim = INT_MIN;
    }

    pt = stringEnd;

    /* Only whitespace is allowed after the closing brace */
    while (SparsevecIsspace(*pt)) {
        pt++;
    }

    if (*pt != '\0')
        ereport(ERROR,
                (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
                 errmsg("invalid input syntax for type sparsevec: \"%s\"", lit), errdetail("Junk after closing.")));

    CheckDim(dim);
    CheckNnz(nnz, dim);
    CheckExpectedDim(typmod, dim);

    qsort(elements, nnz, sizeof(SparseInputElement), CompareIndices);

    result = InitSparseVector(dim, nnz);
    rvalues = SPARSEVEC_VALUES(result);
    for (int i = 0; i < nnz; i++) {
        result->indices[i] = elements[i].index;
        rvalues[i] = elements[i].value;

        /* After sorting, this catches out-of-range and duplicate indices */
        CheckIndex(result->indices, i, dim);
    }

    pfree(elements);

    PG_RETURN_POINTER(result);
}

/* Write character c at ptr and advance ptr past it */
#define AppendChar(ptr, c) (*(ptr)++ = (c))
/* Write the shortest decimal form of f at ptr and advance ptr by the character count reported */
#define AppendFloat(ptr, f) ((ptr) += float_to_shortest_decimal_bufn((f), (ptr)))

/*
 * Write integer i in decimal at ptr, then advance ptr to the NUL that follows
 * the digits, so the next append overwrites that terminator.
 * NOTE(review): ptr is evaluated several times and is not parenthesized;
 * pass a plain variable only.
 */
#define AppendInt(ptr, i)    \
    do {                     \
        pg_ltoa(i, ptr);     \
        while (*ptr != '\0') \
            ptr++;           \
    } while (0)

/*
 * Convert internal representation to textual representation
 *
 * Emits "{i1:v1,...}/dim" using 1-based indices and the shortest decimal
 * form of each value.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_out);
Datum sparsevec_out(PG_FUNCTION_ARGS)
{
    SparseVector *svec = PG_GETARG_SPARSEVEC_P(0);
    float *vals = SPARSEVEC_VALUES(svec);
    int count = svec->nnz;
    char *out;
    char *cursor;

    /*
     * Buffer budget per element: 10 bytes of index digits, 1 for ':',
     * FLOAT_SHORTEST_DECIMAL_LEN - 1 for the value, and 1 for ','.
     * Fixed part: 10 bytes for the dimension digits, plus '{', '}', '/' and
     * the terminating NUL (one comma fewer than elements is also accounted).
     */
    out = (char *)palloc((11 + FLOAT_SHORTEST_DECIMAL_LEN) * count + 13);
    cursor = out;

    AppendChar(cursor, '{');

    for (int k = 0; k < count; k++) {
        if (k != 0) {
            AppendChar(cursor, ',');
        }

        /* Internal indices are 0-based; the text form is 1-based */
        AppendInt(cursor, svec->indices[k] + 1);
        AppendChar(cursor, ':');
        AppendFloat(cursor, vals[k]);
    }

    AppendChar(cursor, '}');
    AppendChar(cursor, '/');
    AppendInt(cursor, svec->dim);
    *cursor = '\0';

    PG_FREE_IF_COPY(svec, 0);
    PG_RETURN_CSTRING(out);
}

/*
 * Convert type modifier
 *
 * Accepts exactly one integer modifier, the dimension count, which must lie
 * in [1, SPARSEVEC_MAX_DIM].
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_typmod_in);
Datum sparsevec_typmod_in(PG_FUNCTION_ARGS)
{
    ArrayType *arr = PG_GETARG_ARRAYTYPE_P(0);
    int count;
    int32 *mods = ArrayGetIntegerTypmods(arr, &count);
    int32 dims;

    if (count != 1) {
        ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("invalid type modifier")));
    }

    dims = mods[0];

    if (dims < 1) {
        ereport(ERROR,
                (errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("dimensions for type sparsevec must be at least 1")));
    }

    if (dims > SPARSEVEC_MAX_DIM) {
        ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                        errmsg("dimensions for type sparsevec cannot exceed %d", SPARSEVEC_MAX_DIM)));
    }

    PG_RETURN_INT32(dims);
}

/*
 * Convert external binary representation to internal representation
 *
 * Wire layout:
 *   int32 dim
 *   int32 nnz
 *   int32 unused (must be 0)
 *   nnz int32 zero-based indices
 *   nnz float4 values
 *
 * Each part is validated as it is read. Zero values are rejected.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_recv);
Datum sparsevec_recv(PG_FUNCTION_ARGS)
{
    StringInfo msg = (StringInfo)PG_GETARG_POINTER(0);
    int32 typmod = PG_GETARG_INT32(2);
    int32 dim = pq_getmsgint(msg, sizeof(int32));
    int32 nnz = pq_getmsgint(msg, sizeof(int32));
    int32 unused = pq_getmsgint(msg, sizeof(int32));
    SparseVector *result;
    float *vals;

    CheckDim(dim);
    CheckNnz(nnz, dim);
    CheckExpectedDim(typmod, dim);

    if (unused != 0) {
        ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("expected unused to be 0, not %d", unused)));
    }

    result = InitSparseVector(dim, nnz);
    vals = SPARSEVEC_VALUES(result);

    /* All indices come first on the wire, already zero-based */
    for (int k = 0; k < nnz; k++) {
        result->indices[k] = pq_getmsgint(msg, sizeof(int32));
        CheckIndex(result->indices, k, dim);
    }

    /* Followed by all values */
    for (int k = 0; k < nnz; k++) {
        vals[k] = pq_getmsgfloat4(msg);
        CheckElement(vals[k]);

        if (vals[k] == 0) {
            ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION),
                            errmsg("binary representation of sparsevec cannot contain zero values")));
        }
    }

    PG_RETURN_POINTER(result);
}

/*
 * Convert internal representation to the external binary representation
 *
 * Writes dim, nnz and the unused field, then all zero-based indices, then
 * all float4 values; this is the layout sparsevec_recv reads back.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_send);
Datum sparsevec_send(PG_FUNCTION_ARGS)
{
    SparseVector *vec = PG_GETARG_SPARSEVEC_P(0);
    float *vals = SPARSEVEC_VALUES(vec);
    int count = vec->nnz;
    StringInfoData out;

    pq_begintypsend(&out);

    /* Header */
    pq_sendint(&out, vec->dim, sizeof(int32));
    pq_sendint(&out, count, sizeof(int32));
    pq_sendint(&out, vec->unused, sizeof(int32));

    /* Indices stay zero-based on the wire */
    for (int k = 0; k < count; k++) {
        pq_sendint(&out, vec->indices[k], sizeof(int32));
    }

    for (int k = 0; k < count; k++) {
        pq_sendfloat4(&out, vals[k]);
    }

    PG_RETURN_BYTEA_P(pq_endtypsend(&out));
}

/*
 * Convert sparse vector to sparse vector
 *
 * An identity cast whose only job is to check the type modifier (argument 1)
 * against the vector's dimension count. The input is returned unchanged.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec);
Datum sparsevec(PG_FUNCTION_ARGS)
{
    SparseVector *vec = PG_GETARG_SPARSEVEC_P(0);
    int32 expected = PG_GETARG_INT32(1);

    CheckExpectedDim(expected, vec->dim);

    PG_RETURN_POINTER(vec);
}

/*
 * Convert dense vector to sparse vector
 *
 * Makes two passes: the first counts non-zero entries so the result can be
 * sized exactly, and the second copies them with their 0-based positions.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_to_sparsevec);
Datum vector_to_sparsevec(PG_FUNCTION_ARGS)
{
    Vector *vec = PG_GETARG_VECTOR_P(0);
    int32 typmod = PG_GETARG_INT32(1);
    int dim = vec->dim;
    int nonzero = 0;
    int out = 0;
    SparseVector *result;
    float *vals;

    CheckDim(dim);
    CheckExpectedDim(typmod, dim);

    for (int k = 0; k < dim; k++) {
        nonzero += (vec->x[k] != 0) ? 1 : 0;
    }

    result = InitSparseVector(dim, nonzero);
    vals = SPARSEVEC_VALUES(result);

    for (int k = 0; k < dim; k++) {
        if (vec->x[k] == 0) {
            continue;
        }

        /* Safety check */
        if (out >= result->nnz)
            elog(ERROR, "safety check failed");

        result->indices[out] = k;
        vals[out] = vec->x[k];
        out++;
    }

    PG_RETURN_POINTER(result);
}

/*
 * Convert half vector to sparse vector
 *
 * Counts the non-zero half-precision entries first, then copies each one
 * (widened to float4) together with its 0-based position.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(halfvec_to_sparsevec);
Datum halfvec_to_sparsevec(PG_FUNCTION_ARGS)
{
    HalfVector *vec = PG_GETARG_HALFVEC_P(0);
    int32 typmod = PG_GETARG_INT32(1);
    int dim = vec->dim;
    int nonzero = 0;
    int out = 0;
    SparseVector *result;
    float *vals;

    CheckDim(dim);
    CheckExpectedDim(typmod, dim);

    for (int k = 0; k < dim; k++) {
        nonzero += HalfIsZero(vec->x[k]) ? 0 : 1;
    }

    result = InitSparseVector(dim, nonzero);
    vals = SPARSEVEC_VALUES(result);

    for (int k = 0; k < dim; k++) {
        if (HalfIsZero(vec->x[k])) {
            continue;
        }

        /* Safety check */
        if (out >= result->nnz)
            elog(ERROR, "safety check failed");

        result->indices[out] = k;
        vals[out] = HalfToFloat4(vec->x[k]);
        out++;
    }

    PG_RETURN_POINTER(result);
}

/*
 * Get the L2 squared distance between sparse vectors
 *
 * Single merge pass over both index lists, which the input/recv paths keep
 * strictly ascending:
 * - an index present in both vectors contributes (a - b)^2;
 * - an index present in only one contributes that value squared.
 * Terms are accumulated in float in ascending index order. Callers in this
 * file call CheckDims before using it.
 */
static float SparsevecL2SquaredDistance(SparseVector *a, SparseVector *b)
{
    float *ax = SPARSEVEC_VALUES(a);
    float *bx = SPARSEVEC_VALUES(b);
    float distance = 0.0;
    /* First element of b not yet consumed by the merge */
    int bpos = 0;

    for (int i = 0; i < a->nnz; i++) {
        int ai = a->indices[i];
        /* -1 never equals a valid index, so an exhausted b counts as "no match" */
        int bi = -1;

        for (int j = bpos; j < b->nnz; j++) {
            bi = b->indices[j];

            if (ai == bi) {
                float diff = ax[i] - bx[j];

                distance += diff * diff;
            } else if (ai > bi)
                distance += bx[j] * bx[j]; /* only b is non-zero at this index */

            /* Update start for next iteration */
            if (ai >= bi)
                bpos = j + 1;

            /* Found or passed it */
            if (bi >= ai)
                break;
        }

        /* only a is non-zero at this index */
        if (ai != bi)
            distance += ax[i] * ax[i];
    }

    /* Remaining elements of b have no counterpart in a */
    for (int j = bpos; j < b->nnz; j++)
        distance += bx[j] * bx[j];

    return distance;
}

/*
 * Get the L2 distance between sparse vectors
 *
 * This is the square root of the L2 squared distance. Errors if the
 * dimension counts differ.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l2_distance);
Datum sparsevec_l2_distance(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    double squared;

    CheckDims(lhs, rhs);

    squared = (double)SparsevecL2SquaredDistance(lhs, rhs);

    PG_RETURN_FLOAT8(sqrt(squared));
}

/*
 * Get the L2 squared distance between sparse vectors
 *
 * Same as the L2 distance but skips the square root. Errors if the dimension
 * counts differ.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l2_squared_distance);
Datum sparsevec_l2_squared_distance(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    double squared;

    CheckDims(lhs, rhs);

    squared = (double)SparsevecL2SquaredDistance(lhs, rhs);

    PG_RETURN_FLOAT8(squared);
}

/*
 * Get the inner product of two sparse vectors
 *
 * Merges the two ascending index lists; only indices present in both vectors
 * contribute, as ax[i] * bx[j], accumulated in float in ascending index order.
 * Callers in this file call CheckDims before using it.
 */
static float SparsevecInnerProduct(SparseVector *a, SparseVector *b)
{
    float *ax = SPARSEVEC_VALUES(a);
    float *bx = SPARSEVEC_VALUES(b);
    /* Running sum of products (named "distance" historically) */
    float distance = 0.0;
    /* First element of b not yet consumed by the merge */
    int bpos = 0;

    for (int i = 0; i < a->nnz; i++) {
        int ai = a->indices[i];

        for (int j = bpos; j < b->nnz; j++) {
            int bi = b->indices[j];

            /* Only update when the same index */
            if (ai == bi)
                distance += ax[i] * bx[j];

            /* Update start for next iteration */
            if (ai >= bi)
                bpos = j + 1;

            /* Found or passed it */
            if (bi >= ai)
                break;
        }
    }

    return distance;
}

/*
 * Get the inner product of two sparse vectors
 *
 * SQL-callable wrapper; errors if the dimension counts differ.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_inner_product);
Datum sparsevec_inner_product(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    double dot;

    CheckDims(lhs, rhs);

    dot = (double)SparsevecInnerProduct(lhs, rhs);

    PG_RETURN_FLOAT8(dot);
}

/*
 * Get the negative inner product of two sparse vectors
 *
 * Negated so that a smaller result means more similar vectors. Errors if the
 * dimension counts differ.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_negative_inner_product);
Datum sparsevec_negative_inner_product(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    double dot;

    CheckDims(lhs, rhs);

    dot = (double)SparsevecInnerProduct(lhs, rhs);

    PG_RETURN_FLOAT8(-dot);
}

/*
 * Get the cosine distance between two sparse vectors
 *
 * Computes 1 - dot(a, b) / sqrt(|a|^2 * |b|^2), with the similarity clamped
 * to [-1, 1]. A zero vector on either side yields NaN (0 / 0).
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_cosine_distance);
Datum sparsevec_cosine_distance(PG_FUNCTION_ARGS)
{
    SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
    float *aVals = SPARSEVEC_VALUES(a);
    float *bVals = SPARSEVEC_VALUES(b);
    float aSquares = 0.0;
    float bSquares = 0.0;
    double sim;

    CheckDims(a, b);

    sim = SparsevecInnerProduct(a, b);

    /* Auto-vectorized */
    for (int k = 0; k < a->nnz; k++)
        aSquares += aVals[k] * aVals[k];

    /* Auto-vectorized */
    for (int k = 0; k < b->nnz; k++)
        bSquares += bVals[k] * bVals[k];

    /* One sqrt of the product rather than a product of two sqrts */
    sim /= sqrt(static_cast<double>(aSquares) * static_cast<double>(bSquares));
#ifdef _MSC_VER
    /* /fp:fast may not propagate NaN */
    if (isnan(sim))
        PG_RETURN_FLOAT8(NAN);
#endif

    /* Rounding can push the ratio slightly outside [-1, 1]; NaN passes through */
    sim = (sim > 1) ? 1.0 : ((sim < -1) ? -1.0 : sim);

    PG_RETURN_FLOAT8(1.0 - sim);
}

/*
 * Get the L1 distance between two sparse vectors
 *
 * Same merge walk as the L2 squared distance:
 * - an index present in both vectors contributes |a - b|;
 * - an index present in only one contributes |value|.
 * Terms are accumulated in float. Errors if the dimension counts differ.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l1_distance);
Datum sparsevec_l1_distance(PG_FUNCTION_ARGS)
{
    SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *b = PG_GETARG_SPARSEVEC_P(1);
    float *ax = SPARSEVEC_VALUES(a);
    float *bx = SPARSEVEC_VALUES(b);
    float distance = 0.0;
    /* First element of b not yet consumed by the merge */
    int bpos = 0;

    CheckDims(a, b);

    for (int i = 0; i < a->nnz; i++) {
        int ai = a->indices[i];
        /* -1 never equals a valid index, so an exhausted b counts as "no match" */
        int bi = -1;

        for (int j = bpos; j < b->nnz; j++) {
            bi = b->indices[j];

            if (ai == bi)
                distance += fabsf(ax[i] - bx[j]);
            else if (ai > bi)
                distance += fabsf(bx[j]); /* only b is non-zero at this index */

            /* Update start for next iteration */
            if (ai >= bi)
                bpos = j + 1;

            /* Found or passed it */
            if (bi >= ai)
                break;
        }

        /* only a is non-zero at this index */
        if (ai != bi)
            distance += fabsf(ax[i]);
    }

    /* Remaining elements of b have no counterpart in a */
    for (int j = bpos; j < b->nnz; j++)
        distance += fabsf(bx[j]);

    PG_RETURN_FLOAT8(static_cast<double>(distance));
}

/*
 * Get the L2 norm of a sparse vector
 *
 * Squares are summed in double precision; implicit zeros contribute nothing.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l2_norm);
Datum sparsevec_l2_norm(PG_FUNCTION_ARGS)
{
    SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
    float *vals = SPARSEVEC_VALUES(a);
    int count = a->nnz;
    double sumSquares = 0.0;

    /* Auto-vectorized */
    for (int k = 0; k < count; k++) {
        double v = vals[k];

        sumSquares += v * v;
    }

    PG_RETURN_FLOAT8(sqrt(sumSquares));
}

/*
 * Report a floating-point overflow. Kept out of line (pg_noinline) so the
 * error path does not bloat its callers.
 */
static pg_noinline void float_overflow_error(void)
{
    ereport(ERROR,
            (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
             errmsg("value out of range: overflow")));
}

/*
 * Normalize a sparse vector with the L2 norm
 *
 * Divides every stored value by the vector's Euclidean length. If the norm
 * is not positive, the zero-filled vector of the same shape is returned.
 * Values that underflow to zero are dropped by building a compacted copy.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_l2_normalize);
Datum sparsevec_l2_normalize(PG_FUNCTION_ARGS)
{
    SparseVector *a = PG_GETARG_SPARSEVEC_P(0);
    float *src = SPARSEVEC_VALUES(a);
    SparseVector *result = InitSparseVector(a->dim, a->nnz);
    float *dst = SPARSEVEC_VALUES(result);
    double sumSquares = 0;
    double norm;
    int dropped = 0;

    /* Auto-vectorized */
    for (int k = 0; k < a->nnz; k++)
        sumSquares += (double)src[k] * (double)src[k];

    norm = sqrt(sumSquares);

    /* Return zero vector for zero norm */
    if (!(norm > 0))
        PG_RETURN_POINTER(result);

    for (int k = 0; k < a->nnz; k++) {
        result->indices[k] = a->indices[k];
        dst[k] = src[k] / norm;

        if (isinf(dst[k]))
            float_overflow_error();

        if (dst[k] == 0)
            dropped++;
    }

    /* Common case: nothing underflowed, so the result is already compact */
    if (dropped == 0)
        PG_RETURN_POINTER(result);

    /* Allocate a new vector in the unlikely event there are zeros */
    SparseVector *compact = InitSparseVector(result->dim, result->nnz - dropped);
    float *cx = SPARSEVEC_VALUES(compact);
    int w = 0;

    for (int k = 0; k < result->nnz; k++) {
        if (dst[k] == 0)
            continue;

        /* Safety check */
        if (w >= compact->nnz)
            elog(ERROR, "safety check failed");

        compact->indices[w] = result->indices[k];
        cx[w] = dst[k];
        w++;
    }

    pfree(result);

    PG_RETURN_POINTER(compact);
}

/*
 * Internal helper to compare sparse vectors
 *
 * Orders vectors as if they were dense arrays compared position by position,
 * with implicit zeros included. Only when every overlapping position is equal
 * does the dimension count decide (fewer dimensions sorts first).
 * Returns -1, 0 or 1. Relies on both index lists being strictly ascending.
 */
static int sparsevec_cmp_internal(SparseVector *a, SparseVector *b)
{
    float *ax = SPARSEVEC_VALUES(a);
    float *bx = SPARSEVEC_VALUES(b);
    int nnz = Min(a->nnz, b->nnz);

    /* Check values before dimensions to be consistent with Postgres arrays */
    for (int i = 0; i < nnz; i++) {
        /* At position a->indices[i], a holds ax[i] while b holds an implicit zero */
        if (a->indices[i] < b->indices[i]) {
            return ax[i] < 0 ? -1 : 1;
        }

        /* At position b->indices[i], b holds bx[i] while a holds an implicit zero */
        if (a->indices[i] > b->indices[i]) {
            return bx[i] < 0 ? 1 : -1;
        }

        if (ax[i] < bx[i]) {
            return -1;
        }

        if (ax[i] > bx[i]) {
            return 1;
        }
    }

    /* b's next non-zero falls inside a's dimensions, where a is zero */
    if (a->nnz < b->nnz && b->indices[nnz] < a->dim) {
        return bx[nnz] < 0 ? 1 : -1;
    }

    /* a's next non-zero falls inside b's dimensions, where b is zero */
    if (a->nnz > b->nnz && a->indices[nnz] < b->dim) {
        return ax[nnz] < 0 ? -1 : 1;
    }

    /* All shared positions are equal: the shorter vector sorts first */
    if (a->dim < b->dim) {
        return -1;
    }

    if (a->dim > b->dim) {
        return 1;
    }

    return 0;
}

/*
 * Less than: true when the first argument sorts before the second
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_lt);
Datum sparsevec_lt(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    int order = sparsevec_cmp_internal(lhs, rhs);

    PG_RETURN_BOOL(order < 0);
}

/*
 * Less than or equal: true unless the first argument sorts after the second
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_le);
Datum sparsevec_le(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    int order = sparsevec_cmp_internal(lhs, rhs);

    PG_RETURN_BOOL(order <= 0);
}

/*
 * Equal: true when the comparison finds no difference
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_eq);
Datum sparsevec_eq(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    int order = sparsevec_cmp_internal(lhs, rhs);

    PG_RETURN_BOOL(order == 0);
}

/*
 * Not equal: true when the comparison finds any difference
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ne);
Datum sparsevec_ne(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    int order = sparsevec_cmp_internal(lhs, rhs);

    PG_RETURN_BOOL(order != 0);
}

/*
 * Greater than or equal: true unless the first argument sorts before the second
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_ge);
Datum sparsevec_ge(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    int order = sparsevec_cmp_internal(lhs, rhs);

    PG_RETURN_BOOL(order >= 0);
}

/*
 * Greater than: true when the first argument sorts after the second
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_gt);
Datum sparsevec_gt(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    int order = sparsevec_cmp_internal(lhs, rhs);

    PG_RETURN_BOOL(order > 0);
}

/*
 * Compare sparse vectors, returning -1, 0 or 1 (btree support function)
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(sparsevec_cmp);
Datum sparsevec_cmp(PG_FUNCTION_ARGS)
{
    SparseVector *lhs = PG_GETARG_SPARSEVEC_P(0);
    SparseVector *rhs = PG_GETARG_SPARSEVEC_P(1);
    int order = sparsevec_cmp_internal(lhs, rhs);

    PG_RETURN_INT32(order);
}