/*
 * This file is part of the openHiTLS project.
 *
 * openHiTLS is licensed under the Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *     http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#include "hitls_build.h"
#ifdef HITLS_CRYPTO_MCELIECE
#include "bsl_sal.h"
#include "mceliece_local.h"
#include "bsl_err_internal.h"
#include "crypt_utils.h"

// Calculate the syndrome of a received vector r.
// For every set bit i of r, the term alpha[i]^j / g(alpha[i])^2 is added into syndrome[j].
// Input:  received - length-n bit vector where r[0..mt-1] contains the ciphertext bits and the rest are zero
//         g        - polynomial that is evaluated at every alpha[i]
//         alpha    - n field elements, indexed by bit position
//         params   - supplies n (vector length in bits) and t (the syndrome has 2t elements)
// Output: syndrome[0..2t-1]
// Return: CRYPT_SUCCESS, or CRYPT_MEM_ALLOC_FAIL when a scratch buffer cannot be allocated
int32_t ComputeSyndrome(const uint8_t *received, const GFPolynomial *g, const GFElement *alpha,
                        const McelieceParams *params, GFElement *syndrome)
{
    const int32_t syndLen = params->t << 1;
    uint32_t full64 = params->n >> 6; // number of complete 64-bit words inside the n-bit vector

    GFElement *gAlpha = (GFElement *)BSL_SAL_Malloc(params->n * sizeof(GFElement));
    GFElement *invG2 = (GFElement *)BSL_SAL_Malloc(params->n * sizeof(GFElement));
    if (gAlpha == NULL || invG2 == NULL) {
        BSL_SAL_FREE(gAlpha);
        BSL_SAL_FREE(invG2);
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    // Precompute g(alpha[i]) and 1 / g(alpha[i])^2 once; both are reused for every syndrome index j.
    for (int32_t i = 0; i < params->n; i++) {
        gAlpha[i] = PolynomialEval(g, alpha[i]);
        invG2[i] = GFInverse(GFMultiplication(gAlpha[i], gAlpha[i]));
    }

    for (int32_t j = 0; j < syndLen; j++) {
        GFElement acc = 0;
        // Word path: scan the vector 64 bits at a time and skip words without any set bit.
        for (uint32_t i64 = 0; i64 < full64; i64++) {
            uint64_t w = GET_UINT64_LE(received, i64 * 8);
            if (w == 0) {
                continue;
            }
            for (int32_t b = 0; b < 64; b++) {
                if ((w & (1ull << b)) != 0) {
                    uint32_t i = (i64 << 6) + b;
                    GFElement t = GFMultiplication(GFPower(alpha[i], j), invG2[i]);
                    acc = GFAddtion(acc, t);
                }
            }
        }
        syndrome[j] = acc;

        // Tail: the remaining (n mod 64) bits are read one at a time.
        // NOTE(review): unlike the word path, the tail divides directly and skips positions where
        // g(alpha[i]) == 0; the two paths agree only if GFInverse(0) yields 0 - confirm.
        for (uint32_t i = full64 * 64; i < (uint32_t)params->n; i++) {
            if (VectorGetBit(received, i) != 0) {
                if (gAlpha[i] != 0) {
                    GFElement alphaPow = GFPower(alpha[i], j);
                    GFElement g2 = GFMultiplication(gAlpha[i], gAlpha[i]);
                    GFElement term = GFDivision(alphaPow, g2);
                    syndrome[j] = GFAddtion(syndrome[j], term);
                }
            }
        }
    }
    BSL_SAL_ClearFree(gAlpha, params->n * sizeof(GFElement));
    BSL_SAL_ClearFree(invG2, params->n * sizeof(GFElement));
    return CRYPT_SUCCESS;
}

// Set the Berlekamp-Massey start state: C(x) = 1, B(x) = x, LFSR length 0, last discrepancy 1.
static void BmInitState(GFPolynomial *polyC, GFPolynomial *polyB, int32_t *lenLFSR, GFElement *b)
{
    // Scalars first, then the two polynomials; the four assignments are independent of each other.
    *b = 1;
    *lenLFSR = 0;
    PolynomialSetCoeff(polyB, 1, 1);
    PolynomialSetCoeff(polyC, 0, 1);
}

// Discrepancy for step N: sum over i = 0..min(t, N) of C_i * s_{N-i}.
// The sum is capped at min(t, N) so that neither polyC->coeffs nor syndrome is indexed out of range.
static GFElement BmComputeDiscrepancy(const GFElement *syndrome, const GFPolynomial *polyC, const int32_t lenN,
                                      const int32_t t)
{
    int32_t last = lenN;
    if (t <= lenN) {
        last = t;
    }

    GFElement disc = 0;
    int32_t k = 0;
    while (k <= last) {
        GFElement prod = GFMultiplication(polyC->coeffs[k], syndrome[lenN - k]);
        disc = GFAddtion(disc, prod);
        k++;
    }
    return disc;
}

// Write the coefficients of polyC into sigma in reversed order: sigma[i] = C[t - i] for i = 0..t.
static void BmExportSigma(const GFPolynomial *polyC, GFPolynomial *sigma, const int32_t t)
{
    // src walks polyC from its highest index down while dst walks sigma upwards.
    for (int32_t src = t, dst = 0; src >= 0; src--, dst++) {
        sigma->coeffs[dst] = polyC->coeffs[src];
    }
}

// Berlekamp-Massey Algorithm according to Classic McEliece specification
// Computes only the error locator polynomial sigma (no error evaluator polynomial is produced).
// Input:  syndrome sequence s[0], s[1], ..., s[2t-1]; params supplies t
// Output: sigma, filled with the coefficients of the final connection polynomial in reversed order
// Return: CRYPT_SUCCESS, CRYPT_NULL_INPUT if syndrome or sigma is NULL,
//         CRYPT_MEM_ALLOC_FAIL if a working polynomial cannot be created
// The per-iteration updates are selected with bit masks instead of branches on d and lenLFSR.
static int32_t BerlekampMassey(const GFElement *syndrome, GFPolynomial *sigma, const McelieceParams *params)
{
    if (syndrome == NULL || sigma == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }

    // polyC: current connection polynomial, polyB: previous one (shifted), polyT: scratch copy of polyC
    GFPolynomial *polyC = PolynomialCreate(params->t);
    GFPolynomial *polyB = PolynomialCreate(params->t);
    GFPolynomial *polyT = PolynomialCreate(params->t);

    if (polyC == NULL || polyB == NULL || polyT == NULL) {
        PolynomialFree(polyC);
        PolynomialFree(polyB);
        PolynomialFree(polyT);
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    int32_t lenLFSR;
    GFElement b;
    BmInitState(polyC, polyB, &lenLFSR, &b);
    for (int32_t lenN = 0; lenN < 2 * params->t; lenN++) {
        GFElement d = BmComputeDiscrepancy(syndrome, polyC, lenN, params->t);
        // dMask: 0 when d == 0, 0xFFFF when d is in 1..0x8000
        // (assumes GFElement is an unsigned type of at most 16 bits - TODO confirm)
        uint16_t dMask = ((uint16_t)(d - 1) >> 15 ) - 1;
        // nMask: 0xFFFF when lenN - 2 * lenLFSR is non-negative, 0 when it is negative
        uint16_t nMask = ((uint16_t)(lenN - (lenLFSR << 1)) >> 15) - 1;
        // After this AND, nMask is set only when d != 0 and 2 * lenLFSR <= lenN (the "length change" case)
        nMask &= dMask;
        // Keep the pre-update polyC; it becomes the new polyB in the length-change case
        for (int32_t i = 0; i <= params->t; i++) {
            polyT->coeffs[i] = polyC->coeffs[i];
        }
        // b starts at 1 and is only replaced by d when nMask is set (which requires d != 0)
        GFElement corr = GFDivision(d, b);
        // polyC ^= (d / b) * polyB, applied only when d != 0
        for (int32_t i = 0; i <= params->t; i++) {
            GFElement term = GFMultiplication(corr, polyB->coeffs[i]);
            polyC->coeffs[i] ^= (term & dMask);
        }
        // Masked select: lenLFSR becomes lenN + 1 - lenLFSR when nMask is set, otherwise it is unchanged
        lenLFSR = (((~nMask) & lenLFSR) | (nMask & (lenN + 1 - lenLFSR)));
        // Masked select: polyB becomes the saved polyC when nMask is set, otherwise it is unchanged
        for (int32_t i = 0; i <= params->t; i++) {
            polyB->coeffs[i] = (((~nMask) & polyB->coeffs[i]) | (nMask & polyT->coeffs[i]));
        }
        // Masked select: b becomes d when nMask is set, otherwise it is unchanged
        b = (((~nMask) & b) | (nMask & d));
        // Multiply polyB by x: shift every coefficient one index up (the old coeffs[t] is dropped)
        for (int32_t i = params->t; i >= 1; --i) {
            polyB->coeffs[i] = polyB->coeffs[i - 1];
        }
        polyB->coeffs[0] = 0;
    }
    // sigma receives polyC with the coefficient order reversed
    BmExportSigma(polyC, sigma, params->t);
    PolynomialFree(polyC);
    PolynomialFree(polyB);
    PolynomialFree(polyT);
    return CRYPT_SUCCESS;
}

// Chien Search: find roots of the error locator polynomial.
// Our BM produces a locator defined in terms of α_j^{-1}, so check σ(α_j^{-1}) = 0
// (NOTE(review): the actual evaluation is delegated to PolynomialRoots - confirm it matches this convention).
// Input:  sigma  - error locator polynomial
//         alpha  - n field elements, one per bit position
//         params - supplies n and t
// Output: errorPositions[0..*numErrors-1] - indices j whose image is 0, in ascending order;
//                                           the buffer must hold at least t entries
//         numErrors                       - number of positions written; always reset by this function,
//                                           never more than t
// Return: CRYPT_SUCCESS, CRYPT_NULL_INPUT on a NULL argument, CRYPT_MEM_ALLOC_FAIL on allocation failure
static int32_t ChienSearch(const GFPolynomial *sigma, const GFElement *alpha, int32_t *errorPositions,
                           int32_t *numErrors, const McelieceParams *params)
{
    if (sigma == NULL || alpha == NULL || errorPositions == NULL || numErrors == NULL || params == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    // Do not rely on the caller having zeroed the counter: it is used as a write index below.
    *numErrors = 0;

    GFElement *images = (GFElement *)BSL_SAL_Malloc(params->n * sizeof(GFElement));
    if (images == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    PolynomialRoots(images, sigma->coeffs, alpha, params->n, params->t);

    // The capacity check happens before every store, so at most t positions are ever written.
    int32_t found = 0;
    for (int32_t j = 0; j < params->n && found < params->t; j++) {
        if (images[j] == 0) {
            // Found a root, corresponding to error position j
            errorPositions[found] = j;
            found++;
        }
    }
    *numErrors = found;
    BSL_SAL_ClearFree(images, params->n * sizeof(GFElement));
    return CRYPT_SUCCESS;
}

// Returns true when every one of the t2 syndrome elements is zero.
// All elements are OR-ed together without an early exit, so the whole array is always read.
static bool IsZeroSyndrome(const GFElement *s, const int32_t t2)
{
    uint16_t bits = 0;
    for (int32_t left = t2; left > 0; left--) {
        bits |= s[left - 1];
    }
    return (bits == 0);
}

// Run Berlekamp-Massey on the syndrome and, if that succeeds, Chien search on the resulting locator.
// pos/cnt receive the error positions and their count; g is accepted but not used here.
// Returns the status of the first failing step, or CRYPT_SUCCESS.
static int32_t LocateErrors(const GFElement *syn, const GFPolynomial *g, const GFElement *alpha, int32_t *pos,
                            int32_t *cnt, const McelieceParams *p)
{
    (void)g;

    GFPolynomial *locator = PolynomialCreate(p->t);
    if (locator == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    int32_t status = BerlekampMassey(syn, locator, p);
    if (status != CRYPT_SUCCESS) {
        PolynomialFree(locator);
        return status;
    }

    status = ChienSearch(locator, alpha, pos, cnt, p);
    PolynomialFree(locator);
    return status;
}

// Build an n-bit vector from a list of bit positions.
// The first ceil(n / 8) bytes of vec are cleared, then bit pos[k] is set for each of the cnt entries;
// entries outside [0, n) are ignored.
static void PosToBits(uint8_t *vec, const int32_t *pos, const int32_t cnt, const int32_t n)
{
    memset(vec, 0, (n + 7U) >> 3);

    for (int32_t k = 0; k < cnt; k++) {
        const int32_t bitIdx = pos[k];
        if (bitIdx < 0 || bitIdx >= n) {
            continue; // out-of-range position, nothing to mark
        }
        VectorSetBit(vec, bitIdx, 1);
    }
}

// Decode a received vector: compute its syndrome, locate the error positions and emit the error vector.
// Input:  received       - received bit vector, passed to ComputeSyndrome
//         g, alpha       - polynomial and field elements, passed to ComputeSyndrome and LocateErrors
//         params         - supplies n and t
// Output: errorVector    - ceil(n / 8) bytes; bit j is set for every located error position j.
//                          It is written on every successful return, including the zero-syndrome case
//                          (all-zero vector).
//         decodeSyndrome - the 2t syndrome elements computed from `received`
// Return: CRYPT_SUCCESS, CRYPT_MEM_ALLOC_FAIL, or the error returned by a failing step
int32_t DecodeGoppa(const uint8_t *received, const GFPolynomial *g, const GFElement *alpha,
                    const McelieceParams *params, uint8_t *errorVector, GFElement *decodeSyndrome)
{
    int32_t *errorPos = BSL_SAL_Malloc(params->t * sizeof(int32_t));
    if (errorPos == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    int32_t ret = ComputeSyndrome(received, g, alpha, params, decodeSyndrome);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        BSL_SAL_FREE(errorPos); // nothing has been stored in errorPos yet, no need to wipe it
        return ret;
    }

    // A zero syndrome means there is nothing to locate: keep numErrors at 0 and skip BM/Chien.
    int32_t numErrors = 0;
    if (!IsZeroSyndrome(decodeSyndrome, 2 * params->t)) {
        ret = LocateErrors(decodeSyndrome, g, alpha, errorPos, &numErrors, params);
        if (ret != CRYPT_SUCCESS) {
            BSL_ERR_PUSH_ERROR(ret);
            BSL_SAL_ClearFree(errorPos, params->t * sizeof(int32_t));
            return ret;
        }
    }

    // Always go through PosToBits so errorVector is cleared even when no error was found;
    // with numErrors == 0 it only zeroes the vector and never reads errorPos.
    PosToBits(errorVector, errorPos, numErrors, params->n);
    BSL_SAL_ClearFree(errorPos, params->t * sizeof(int32_t));
    return CRYPT_SUCCESS;
}
#endif