/*
 * Copyright (C) 2025 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "credential_message.h"
#include <cerrno>
#include <charconv>
#include "sys/socket.h"
#include "base.h"

using namespace Hdc;

// Filesystem path of the Unix domain socket that SendMessageByUnixSocket() connects to.
const std::string HDC_CREDENTIAL_SOCKET_SANDBOX_PATH{"/data/hdc/hdc_huks/hdc_credential.socket"};

CredentialMessage::CredentialMessage(const std::string& messageStr)
{
    // All parsing and validation lives in Init(). When the input is rejected, the
    // members Init() did not reach keep their initial values.
    this->Init(messageStr);
}

// Parses a serialized credential message into this object.
// Layout (offsets from the MESSAGE_* constants):
//   [version: 1 ASCII digit][method: MESSAGE_METHOD_LEN digits]
//   [body length: MESSAGE_LENGTH_LEN digits][body: <body length> bytes]
// On any validation failure a FATAL log is written and parsing stops; fields that
// were not yet assigned keep their previous values.
void CredentialMessage::Init(const std::string& messageStr)
{
    // The fixed-size header must be fully present before any field is read.
    if (messageStr.empty() || messageStr.length() < MESSAGE_BODY_POS) {
        WRITE_LOG(LOG_FATAL, "messageStr is too short!");
        return;
    }

    int versionInt = messageStr[MESSAGE_VERSION_POS] - '0';
    if (versionInt < METHOD_CRYPTO_KEY || versionInt > METHOD_VERSION_MAX) {
        WRITE_LOG(LOG_FATAL, "Invalid message version %d.", versionInt);
        return;
    }

    messageVersion = versionInt;

    std::string messageMethodStr = messageStr.substr(MESSAGE_METHOD_POS, MESSAGE_METHOD_LEN);
    messageMethodType = StripLeadingZeros(messageMethodStr);
    if (messageMethodType < 0) {
        // StripLeadingZeros() signals a non-numeric field with -1; do not go on to
        // parse a body for a message whose method cannot be identified.
        WRITE_LOG(LOG_FATAL, "Invalid message method %s.", messageMethodStr.c_str());
        return;
    }

    std::string messageLengthStr = messageStr.substr(MESSAGE_LENGTH_POS, MESSAGE_LENGTH_LEN);
    const char* lengthBegin = messageLengthStr.data();
    const char* lengthEnd = lengthBegin + messageLengthStr.size();
    size_t bodyLength = 0;
    auto [ptr, ec] = std::from_chars(lengthBegin, lengthEnd, bodyLength);
    // from_chars stops at the first non-digit, so also require that the whole field was
    // consumed; otherwise e.g. "12ab" would be accepted as a length of 12.
    if (ec != std::errc() || ptr != lengthEnd) {
        bodyLength = 0;
    }
    if (bodyLength == 0 || bodyLength > MESSAGE_STR_MAX_LEN) {
        WRITE_LOG(LOG_FATAL, "Invalid message body length %s.", messageLengthStr.c_str());
        return;
    }

    if (messageStr.length() < MESSAGE_BODY_POS + bodyLength) {
        WRITE_LOG(LOG_FATAL, "messageStr is too short.");
        return;
    }

    messageBodyLen = static_cast<int>(bodyLength);
    messageBody = messageStr.substr(MESSAGE_BODY_POS, bodyLength);
}
CredentialMessage::~CredentialMessage()
{
    // Overwrite the body bytes with zeros before the string releases its storage.
    if (messageBody.empty()) {
        return;
    }
    const size_t bodySize = messageBody.size();
    memset_s(&messageBody[0], bodySize, 0, bodySize);
}

void CredentialMessage::SetMessageVersion(int version)
{
    // Only versions in [METHOD_CRYPTO_KEY, METHOD_VERSION_MAX] are stored; anything
    // else is logged and the current version is kept.
    const bool belowMin = version < METHOD_CRYPTO_KEY;
    const bool aboveMax = version > METHOD_VERSION_MAX;
    if (belowMin || aboveMax) {
        WRITE_LOG(LOG_FATAL, "Invalid message version %d.", version);
        return;
    }
    messageVersion = version;
}

void CredentialMessage::SetMessageBody(const std::string& body)
{
    // Store body and its length together; bodies longer than MESSAGE_STR_MAX_LEN are
    // rejected and leave the current body and length unchanged.
    if (!(body.size() > MESSAGE_STR_MAX_LEN)) {
        messageBody = body;
        messageBodyLen = static_cast<int>(messageBody.size());
        return;
    }
    WRITE_LOG(LOG_FATAL, "Message body length exceeds maximum allowed length.");
}

std::string CredentialMessage::Construct() const
{
    size_t totalLength = 0;
    totalLength += 1;
    totalLength += MESSAGE_METHOD_LEN;
    totalLength += MESSAGE_LENGTH_LEN;
    totalLength += messageBody.size();

    std::string messageMethodTypeStr = IntToStringWithPadding(messageMethodType, MESSAGE_METHOD_LEN);
    if (messageMethodTypeStr.size() != MESSAGE_METHOD_LEN) {
        WRITE_LOG(LOG_FATAL, "messageMethod length Error!");
        return "";
    }

    std::string messageBodyLenStr = IntToStringWithPadding(messageBodyLen, MESSAGE_LENGTH_LEN);
    if (messageBodyLenStr.empty() || (messageBodyLenStr.size() > MESSAGE_LENGTH_LEN)) {
        WRITE_LOG(LOG_FATAL, "messageBodyLen length must be:%d,now is:%s",
            MESSAGE_LENGTH_LEN, messageBodyLenStr.c_str());
        return "";
    }
    
    std::string result;
    result.reserve(totalLength);
    result.push_back('0' + messageVersion);
    result.append(messageMethodTypeStr);
    result.append(messageBodyLenStr);
    result.append(messageBody);

    if (result.size() != totalLength) {
        WRITE_LOG(LOG_FATAL, "size mismatch. Expected: %zu, Actual: %zu", totalLength, result.size());
        return "";
    }

    return result;
}

// Returns true when str is non-empty and every character is a decimal digit '0'-'9'.
// Signs, whitespace and separators make the result false.
bool IsNumeric(const std::string& str)
{
    if (str.empty()) {
        return false;
    }
    for (char ch : str) {
        // std::isdigit is undefined for negative arguments other than EOF, and plain char
        // is signed on common ABIs, so bytes >= 0x80 must go through unsigned char.
        if (!std::isdigit(static_cast<unsigned char>(ch))) {
            return false;
        }
    }
    return true;
}

// Converts a zero-padded decimal field (e.g. "0012") to an int.
// Returns 0 for "" or an all-zero field, and -1 (with a FATAL log) when the remaining
// digits are not purely numeric or the value does not fit in an int.
int StripLeadingZeros(const std::string& input)
{
    if (input.empty()) {
        return 0;
    }
    size_t firstNonZero = input.find_first_not_of('0');
    if (firstNonZero == std::string::npos) {
        // Covers "0", "00", ...
        return 0;
    }

    std::string numberStr = input.substr(firstNonZero);
    if (!IsNumeric(numberStr)) {
        WRITE_LOG(LOG_FATAL, "StripLeadingZeros: invalid numeric string.");
        return -1;
    }

    // Parse straight into int. Parsing into long and narrowing afterwards silently
    // truncated values that fit in long but not in int.
    int value = 0;
    auto [ptr, ec] = std::from_chars(numberStr.data(), numberStr.data() + numberStr.size(), value);
    if (ec != std::errc()) {
        // All characters are digits here, so the only possible failure is out-of-range.
        WRITE_LOG(LOG_FATAL, "StripLeadingZeros: value out of range.");
        return -1;
    }
    return value;
}

// Returns exactly len bytes taken from the front of str. If str is shorter than len,
// the missing trailing bytes are zero instead of being read past the end of str.
std::vector<uint8_t> String2Uint8(const std::string& str, size_t len)
{
    const size_t copyLen = (len < str.size()) ? len : str.size();
    std::vector<uint8_t> byteData(len, 0);
    for (size_t i = 0; i < copyLen; i++) {
        byteData[i] = static_cast<uint8_t>(str[i]);
    }
    return byteData;
}

// Formats a non-negative value as decimal, left-padded with '0' to exactly maxLen
// characters (e.g. (5, 4) -> "0005").
// Returns "" when the value needs more than maxLen digits, or when either argument is
// negative.
std::string IntToStringWithPadding(int length, int maxLen)
{
    // A negative value would produce a malformed field such as "00-1" that still has
    // the expected width. A negative maxLen would wrap to a huge size_t below.
    if (length < 0 || maxLen < 0) {
        return "";
    }
    std::string str = std::to_string(length);
    const size_t width = static_cast<size_t>(maxLen);
    if (str.length() > width) {
        return "";
    }
    return std::string(width - str.length(), '0') + str;
}

// Appends the pieces of origString delimited by seq to resultStrings.
// - Empty pieces between or before separators are skipped.
// - The final remainder is always appended, so it may be empty when origString ends
//   with seq.
// - count < 0 means no limit. Otherwise at most count pieces are split off, and the
//   rest of the string is appended as one final piece.
// Nothing is appended when seq or origString is empty.
void SplitString(const std::string &origString, const std::string &seq,
                 std::vector<std::string> &resultStrings, int count)
{
    if (seq.empty() || origString.empty()) {
        return;
    }

    std::string::size_type p1 = 0;
    std::string::size_type p2 = origString.find(seq);
    int splitCount = 0;

    while (p2 != std::string::npos) {
        if (count >= 0 && splitCount >= count) {
            break;
        }
        if (p2 == p1) {
            // Empty piece: step over the whole separator, not a single character, so a
            // multi-character seq does not leave fragments of itself in the next piece.
            p1 += seq.size();
            p2 = origString.find(seq, p1);
            continue;
        }

        resultStrings.push_back(origString.substr(p1, p2 - p1));
        ++splitCount;

        p1 = p2 + seq.size();
        p2 = origString.find(seq, p1);
    }

    // p1 never exceeds origString.size() here, so the remainder is always appended.
    if (p1 <= origString.size()) {
        resultStrings.push_back(origString.substr(p1));
    }
}

// Frames str as a credential message: one version character ('0' + methodVersion),
// methodType zero-padded to MESSAGE_METHOD_LEN, str's length zero-padded to
// MESSAGE_LENGTH_LEN, then str itself. Returns "" (with a FATAL log) on any failure.
std::string SplicMessageStr(const std::string &str, const size_t methodType, const size_t methodVersion)
{
    if (str.empty()) {
        WRITE_LOG(LOG_FATAL, "Input string is empty.");
        return "";
    }

    const std::string methodField = IntToStringWithPadding(methodType, MESSAGE_METHOD_LEN);
    if (methodField.length() != MESSAGE_METHOD_LEN) {
        WRITE_LOG(LOG_FATAL, "messageMethodTypeStr length must be:%d,now is:%s",
            MESSAGE_METHOD_LEN, methodField.c_str());
        return "";
    }

    const std::string lengthField = IntToStringWithPadding(str.length(), MESSAGE_LENGTH_LEN);
    if (lengthField.empty() || (lengthField.length() > MESSAGE_LENGTH_LEN)) {
        WRITE_LOG(LOG_FATAL, "messageBodyLen length must be:%d,now is:%s", MESSAGE_LENGTH_LEN, lengthField.c_str());
        return "";
    }

    const size_t expectedSize = MESSAGE_METHOD_POS + MESSAGE_METHOD_LEN +
                                MESSAGE_LENGTH_LEN + str.size();
    std::string framed;
    framed.reserve(expectedSize);
    framed += static_cast<char>('0' + methodVersion);
    framed.append(methodField).append(lengthField).append(str);

    if (framed.size() != expectedSize) {
        WRITE_LOG(LOG_FATAL, "size mismatch. Expected: %zu, Actual: %zu", expectedSize, framed.size());
        return "";
    }
    return framed;
}

// Connects sockfd to the Unix socket at HDC_CREDENTIAL_SOCKET_SANDBOX_PATH and writes
// all of messageStr to it. Returns false (with a FATAL log) if the path does not fit,
// or if the connect or any send fails. sockfd is not closed here.
bool SendMessageByUnixSocket(const int sockfd, const std::string &messageStr)
{
    struct sockaddr_un addr = {.sun_family = AF_UNIX};
    // Keep room for the terminating '\0' (addr is zero-initialized).
    size_t maxPathLen = sizeof(addr.sun_path) - 1;
    size_t pathLen = HDC_CREDENTIAL_SOCKET_SANDBOX_PATH.size();
    if (pathLen > maxPathLen) {
        WRITE_LOG(LOG_FATAL, "Socket path too long.");
        return false;
    }

    if (memcpy_s(addr.sun_path, maxPathLen, HDC_CREDENTIAL_SOCKET_SANDBOX_PATH.c_str(), pathLen) != 0) {
        WRITE_LOG(LOG_FATAL, "Failed to memcpy_st.");
        return false;
    }

    if (connect(sockfd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
        WRITE_LOG(LOG_FATAL, "Failed to connect to socket.");
        return false;
    }

#ifdef MSG_NOSIGNAL
    // Report a peer that already closed the socket as EPIPE instead of raising
    // SIGPIPE, whose default action terminates the process.
    const int sendFlags = MSG_NOSIGNAL;
#else
    const int sendFlags = 0;
#endif

    // A single send() on a stream socket may write fewer bytes than requested or be
    // interrupted by a signal, so loop until the whole message has been written.
    const char *cursor = messageStr.data();
    size_t remaining = messageStr.size();
    while (remaining > 0) {
        ssize_t sent = send(sockfd, cursor, remaining, sendFlags);
        if (sent < 0 && errno == EINTR) {
            continue;
        }
        if (sent <= 0) {
            WRITE_LOG(LOG_FATAL, "Failed to send message.");
            return false;
        }
        cursor += sent;
        remaining -= static_cast<size_t>(sent);
    }

    return true;
}

// Reads from sockfd until the peer closes the connection, storing the bytes in data
// and appending a '\0'.
// Returns the number of bytes read, or -1 (with a FATAL log) when:
//   - data/size are unusable,
//   - recv() fails, or
//   - the reply does not fit in size - 1 bytes.
ssize_t RecvMessageByUnixSocket(const int sockfd, char data[], ssize_t size)
{
    // size <= 0 would make the recv length below wrap to a huge size_t and put the
    // terminator out of bounds.
    if (data == nullptr || size <= 0) {
        WRITE_LOG(LOG_FATAL, "Invalid receive buffer.");
        return -1;
    }

    ssize_t count = 0;
    ssize_t bytesRead = 0;
    for (;;) {
        bytesRead = recv(sockfd, data + count, static_cast<size_t>(size - 1 - count), 0);
        if (bytesRead > 0) {
            count += bytesRead;
            if (count >= size - 1) {
                // Buffer full: the reply may be truncated. Report an error rather than 0,
                // which the caller would treat as a successful empty read.
                data[count] = '\0';
                WRITE_LOG(LOG_FATAL, "Message too long for receive buffer.");
                return -1;
            }
            continue;
        }
        if (bytesRead < 0 && errno == EINTR) {
            // Interrupted by a signal before any data arrived; retry.
            continue;
        }
        // 0 = orderly shutdown by the peer; < 0 = real error.
        break;
    }

    data[count] = '\0';
    if (bytesRead < 0) {
        WRITE_LOG(LOG_FATAL, "Failed to read from socket.");
        return -1;
    }
    return count;
}

// Sends messageStr to the credential service over a new Unix stream socket and reads
// the reply into data (at most size bytes including the terminator).
// Returns the reply length. Returns -1 if data/size are invalid, the socket cannot be
// created, sending fails, or the receive helper reports an error.
ssize_t GetCredential(const std::string &messageStr, char data[], ssize_t size)
{
    if (data == nullptr || size < static_cast<ssize_t>(MESSAGE_STR_MAX_LEN)) {
        // size is ssize_t, so %zd (not %d) is the matching printf conversion.
        WRITE_LOG(LOG_FATAL, "data is null or size:%zd out of range", size);
        return -1;
    }

    int sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (sockfd < 0) {
        WRITE_LOG(LOG_FATAL, "Failed to create socket.");
        return -1;
    }

    ssize_t count = -1;
    if (SendMessageByUnixSocket(sockfd, messageStr)) {
        count = RecvMessageByUnixSocket(sockfd, data, size);
    }
    // Single close point for every path after the socket was created.
    close(sockfd);
    return (count < 0) ? -1 : count;
}