* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "basictypes/BigInteger.h"
// Throws if the magnitude is too long to be represented.
// A magnitude longer than MAX_MAG_LENGTH words is always rejected; one of exactly
// MAX_MAG_LENGTH words is rejected only when its most significant bit is set.
// NOTE(review): this is the java.math.BigInteger::checkRange rule, which this
// class appears to port; confirm MAX_MAG_LENGTH matches Java's value.
void BigInteger::checkRange() const
{
    const bool tooLong = mag.size() > MAX_MAG_LENGTH;
    const bool topBitAtLimit = !mag.empty() && mag.size() == MAX_MAG_LENGTH && mag[0] < 0;
    if (tooLong || topBitAtLimit) {
        throw std::runtime_error("BigInteger::checkRange error");
    }
}
// In-place x = x * y + z, where x is a big-endian array of 32-bit words
// (x[0] most significant) and y, z are treated as unsigned 32-bit values.
// Any carry out of the most significant word is discarded.
// All arithmetic is done in uint64_t: a 32x32-bit product plus a 32-bit carry
// always fits, whereas the signed int64_t form could overflow (UB).
void BigInteger::destructiveMulAdd(std::vector<int> &x, int y, int z)
{
    if (x.empty()) {
        return;  // nothing to scale; avoids indexing x[-1]
    }
    const uint64_t multiplier = static_cast<uint32_t>(y);
    const uint64_t addend = static_cast<uint32_t>(z);

    // Pass 1: multiply every word by y, propagating the high half as carry.
    uint64_t carry = 0;
    for (size_t i = x.size(); i-- > 0;) {
        const uint64_t product = multiplier * static_cast<uint32_t>(x[i]) + carry;
        x[i] = static_cast<int>(static_cast<uint32_t>(product));
        carry = product >> 32;
    }

    // Pass 2: add z to the least significant word and ripple the carry upward.
    carry = addend;
    for (size_t i = x.size(); i-- > 0 && carry != 0;) {
        const uint64_t sum = static_cast<uint64_t>(static_cast<uint32_t>(x[i])) + carry;
        x[i] = static_cast<int>(static_cast<uint32_t>(sum));
        carry = sum >> 32;
    }
}
// Hash over the magnitude words (h = 31 * h + word, modulo 2^32), multiplied by
// signum. Unsigned arithmetic gives the intended wraparound without the signed
// overflow the original hit for INT_MIN * -1.
int BigInteger::hashCode()
{
    uint32_t hash = 0;
    for (int word : mag) {
        hash = 31u * hash + static_cast<uint32_t>(word);
    }
    return static_cast<int>(hash * static_cast<uint32_t>(signum));
}
// Value equality against an arbitrary Object.
// Returns false for nullptr or for objects that are not BigIntegers; otherwise
// compares sign and magnitude words.
// dynamic_cast replaces the previous reinterpret_cast, which treated any Object
// as a BigInteger (UB for other types).
// NOTE(review): assumes Object is polymorphic (it declares the virtual equals/clone
// that this class overrides) — confirm in the Object header.
bool BigInteger::equals(Object *obj)
{
    if (obj == nullptr) {
        return false;
    }
    const auto *other = dynamic_cast<const BigInteger *>(obj);
    if (other == nullptr) {
        return false;
    }
    if (other == this) {
        return true;
    }
    return other->signum == this->signum && other->mag == this->mag;
}
// Returns a newly heap-allocated BigInteger holding a copy of this value's
// magnitude and sign. The caller owns the returned object.
Object *BigInteger::clone()
{
    BigInteger *duplicate = new BigInteger(mag, signum);
    return duplicate;
}
std::string BigInteger::toString(int radix) const
{
if (signum == 0)
return "0";
std::vector<uint32_t> magCopy;
magCopy.reserve(mag.size());
for (int num : mag) {
magCopy.push_back(static_cast<uint32_t>(num));
}
std::string result;
while (!magCopy.empty()) {
uint32_t remainder = 0;
std::vector<uint32_t> quotient;
for (uint32_t word : magCopy) {
uint64_t value = ((uint64_t)remainder << 32) | word;
quotient.push_back((uint32_t)(value / 10));
remainder = value % 10;
}
result += (char)('0' + remainder);
size_t first_non_zero = 0;
while (first_non_zero < quotient.size() && quotient[first_non_zero] == 0) {
++first_non_zero;
}
magCopy.assign(quotient.begin() + first_non_zero, quotient.end());
}
std::reverse(result.begin(), result.end());
if (signum < 0) {
result.insert(result.begin(), '-');
}
return result;
}
// Parses `val` as an optionally signed integer in the given radix.
// Throws std::runtime_error on:
//  - a radix outside [CHARACTER_MIN_RADIX, CHARACTER_MAX_RADIX];
//  - an empty string or a bare sign;
//  - a sign character anywhere but position 0;
//  - any character that is not a valid digit in `radix`;
//  - a result too large to represent.
// Digits are validated one by one: std::stoi (used previously) skips leading
// whitespace and silently stops at the first invalid character, so input such as
// "12x4" was accepted.
BigInteger::BigInteger(const std::string &val, int radix)
{
    int cursor = 0;
    const int len = static_cast<int>(val.length());
    if (radix < CHARACTER_MIN_RADIX || radix > CHARACTER_MAX_RADIX) {
        throw std::runtime_error("Radix out of range");
    }
    if (len == 0) {
        throw std::runtime_error("Zero length BigInteger");
    }

    // Only a single leading sign is allowed. Using the last occurrence ensures
    // inputs such as "-1-2" are rejected instead of slipping through.
    int sign = 1;
    const size_t minusPos = val.rfind('-');
    const size_t plusPos = val.rfind('+');
    if (minusPos != std::string::npos) {
        if (minusPos != 0 || plusPos != std::string::npos) {
            throw std::runtime_error("Illegal embedded sign character");
        }
        sign = -1;
        cursor = 1;
    } else if (plusPos != std::string::npos) {
        if (plusPos != 0) {
            throw std::runtime_error("Illegal embedded sign character");
        }
        cursor = 1;
    }
    if (cursor == len) {
        throw std::runtime_error("Zero length BigInteger");
    }

    // Skip leading zeros; an all-zero string is the value zero.
    while (cursor < len && val[cursor] == '0') {
        cursor++;
    }
    if (cursor == len) {
        signum = 0;
        mag = std::vector<int>();
        return;
    }

    // Parses val[begin, begin + count) as digits in `radix`, throwing on any
    // character that is not a valid digit.
    auto parseGroup = [&val, radix](int begin, int count) -> int {
        int64_t value = 0;
        for (int i = begin; i < begin + count; ++i) {
            const char c = val[i];
            int digit = radix;  // sentinel meaning "not a digit"
            if (c >= '0' && c <= '9') {
                digit = c - '0';
            } else if (c >= 'a' && c <= 'z') {
                digit = c - 'a' + 10;
            } else if (c >= 'A' && c <= 'Z') {
                digit = c - 'A' + 10;
            }
            if (digit >= radix) {
                throw std::runtime_error("Illegal digit");
            }
            value = value * radix + digit;
        }
        return static_cast<int>(value);
    };

    const int numDigits = len - cursor;
    signum = sign;
    // Widen before multiplying: numDigits * bitsPerDigit can exceed int range.
    const int64_t numBits = ((static_cast<int64_t>(numDigits) * bitsPerDigit[radix]) >> 10) + 1;
    if (numBits + 31 >= (int64_t{1} << 32)) {
        throw std::runtime_error("BigInteger overflow");
    }
    const int numWords = static_cast<int>((numBits + 31) >> 5);
    std::vector<int> magnitude(numWords, 0);

    // The first group absorbs the remainder so all later groups have exactly
    // digitsPerInt[radix] digits.
    int firstGroupLen = numDigits % digitsPerInt[radix];
    if (firstGroupLen == 0) {
        firstGroupLen = digitsPerInt[radix];
    }
    magnitude[numWords - 1] = parseGroup(cursor, firstGroupLen);
    cursor += firstGroupLen;

    const int superRadix = intRadix[radix];
    while (cursor < len) {
        const int groupVal = parseGroup(cursor, digitsPerInt[radix]);
        cursor += digitsPerInt[radix];
        destructiveMulAdd(magnitude, superRadix, groupVal);
    }
    mag = trustedStripLeadingZeroInts(magnitude);
    if (mag.size() >= MAX_MAG_LENGTH) {
        checkRange();
    }
}
// Parses a decimal string by delegating to the (string, radix) constructor with radix 10.
BigInteger::BigInteger(const std::string &val)
    : BigInteger(val, 10)
{
}
BigInteger::BigInteger(String *val)
{
if (val == nullptr) {
throw std::runtime_error("String* is nullptr");
return;
}
BigInteger(val->toString(), 10);
return;
}
// Builds a BigInteger from a signed 64-bit value.
// The magnitude is stored as one word if it fits in 32 bits, otherwise two
// (most significant first).
// Negation is done in uint64_t so that INT64_MIN, whose magnitude 2^63 is not
// representable as int64_t, no longer overflows.
BigInteger::BigInteger(int64_t val)
{
    if (val == 0) {
        signum = 0;
        mag = std::vector<int>();
        return;
    }
    signum = (val < 0) ? -1 : 1;
    const uint64_t magnitude = (val < 0) ? (uint64_t{0} - static_cast<uint64_t>(val))
                                         : static_cast<uint64_t>(val);
    const uint32_t highWord = static_cast<uint32_t>(magnitude >> 32);
    const uint32_t lowWord = static_cast<uint32_t>(magnitude);
    if (highWord == 0) {
        mag = std::vector<int>(1, static_cast<int>(lowWord));
    } else {
        mag = std::vector<int>{static_cast<int>(highWord), static_cast<int>(lowWord)};
    }
}
// Converts a negative big-endian two's-complement word array into the magnitude
// of its value, i.e. computes -val.
// Leading sign-extension words (0xFFFFFFFF) are dropped. If every remaining word
// is zero, the negation carries into one extra leading word (e.g. {-1} -> {1},
// {-1, 0} -> {1, 0}).
// The previous version only complemented words while a carry was pending and
// derived the carry from the low bit, so results were wrong (e.g. {-1, -2} gave
// {-1, 2} instead of {2}).
std::vector<int> BigInteger::makePositive(const std::vector<int> &val)
{
    const size_t len = val.size();
    size_t keep = 0;
    while (keep < len && val[keep] == -1) {
        ++keep;
    }
    size_t j = keep;
    while (j < len && val[j] == 0) {
        ++j;
    }
    const size_t extraInt = (j == len) ? 1 : 0;

    // One's complement of the significant words (the optional extra word stays 0).
    std::vector<uint32_t> words(len - keep + extraInt, 0);
    for (size_t i = keep; i < len; ++i) {
        words[i - keep + extraInt] = ~static_cast<uint32_t>(val[i]);
    }
    // Add one to turn the one's complement into the two's complement.
    for (size_t i = words.size(); i-- > 0;) {
        if (++words[i] != 0) {
            break;
        }
    }

    std::vector<int> result;
    result.reserve(words.size());
    for (uint32_t w : words) {
        result.push_back(static_cast<int>(w));
    }
    return result;
}
// Returns `val` without its leading zero words. If there are none, it returns
// an unchanged copy; if all words are zero, an empty vector.
std::vector<int> BigInteger::trustedStripLeadingZeroInts(const std::vector<int> &val)
{
    auto firstSignificant = val.begin();
    while (firstSignificant != val.end() && *firstSignificant == 0) {
        ++firstSignificant;
    }
    return (firstSignificant == val.begin())
               ? val
               : std::vector<int>(firstSignificant, val.end());
}
// Throws if the magnitude is too long to be represented.
// A magnitude longer than MAX_MAG_LENGTH words is always rejected; one of exactly
// MAX_MAG_LENGTH words is rejected only when its most significant bit is set.
// The previous version missed that second case.
// NOTE(review): this is the java.math.BigInteger::checkRange rule, which this
// class appears to port; confirm MAX_MAG_LENGTH matches Java's value.
void BigInteger::checkRange()
{
    const bool tooLong = mag.size() > MAX_MAG_LENGTH;
    const bool topBitAtLimit = !mag.empty() && mag.size() == MAX_MAG_LENGTH && mag[0] < 0;
    if (tooLong || topBitAtLimit) {
        throw std::runtime_error("BigInteger out of range");
    }
}
// Builds a BigInteger from a big-endian two's-complement word array.
// Throws std::runtime_error for an empty array. A set top bit in val[0] means
// the value is negative.
BigInteger::BigInteger(const std::vector<int> &val)
{
    if (val.empty()) {
        throw std::runtime_error("Zero length BigInteger");
    }
    if (val[0] >= 0) {
        // Non-negative: the magnitude is the words minus any leading zeros.
        mag = trustedStripLeadingZeroInts(val);
        signum = mag.empty() ? 0 : 1;
    } else {
        // Negative: negate to recover the magnitude.
        mag = makePositive(val);
        signum = -1;
    }
    const bool atLimit = !mag.empty() && mag.size() >= MAX_MAG_LENGTH;
    if (atLimit) {
        checkRange();
    }
}
// Trusted constructor: takes the magnitude words (copied as-is) and sign
// without normalising either. An empty magnitude always yields signum 0.
BigInteger::BigInteger(const std::vector<int> &magnitude, int signum)
    : mag(magnitude),
      signum(magnitude.empty() ? 0 : signum)
{
    // Range validation only matters once the magnitude reaches the limit.
    const bool atLimit = !mag.empty() && mag.size() >= MAX_MAG_LENGTH;
    if (atLimit) {
        checkRange();
    }
}
// Allocates a new BigInteger equal to `val`; the caller owns the result.
BigInteger* BigInteger::valueOf(int64_t val)
{
    BigInteger *instance = new BigInteger(val);
    return instance;
}
// Allocates a new BigInteger from a big-endian two's-complement word array; the
// caller owns the result. An empty array yields zero.
BigInteger* BigInteger::valueOf(std::vector<int> &val)
{
    if (val.empty()) {
        return new BigInteger(static_cast<int64_t>(0));
    }
    // A strictly positive leading word needs neither sign handling nor
    // zero-stripping, so the trusted (magnitude, signum) constructor suffices.
    return (val[0] > 0) ? new BigInteger(val, 1) : new BigInteger(val);
}
// Returns the number of set bits in the 32-bit pattern of `value`.
// Counting runs on the unsigned pattern: an arithmetic right shift of a negative
// int never reaches zero, so the previous loop never terminated for negative input.
int BigInteger::bitCount(int value) const
{
    uint32_t bits = static_cast<uint32_t>(value);
    int count = 0;
    while (bits != 0) {
        count += static_cast<int>(bits & 1u);
        bits >>= 1;
    }
    return count;
}
// Returns the position of the highest set bit of the 32-bit pattern of `value`
// (0 for 0, 32 when the top bit is set).
// Works on the unsigned pattern so negative inputs, such as magnitude words with
// the top bit set, terminate instead of looping forever.
int BigInteger::bitLengthForInt(int value) const
{
    uint32_t bits = static_cast<uint32_t>(value);
    int length = 0;
    while (bits != 0) {
        ++length;
        bits >>= 1;
    }
    return length;
}
// Number of bits in the minimal two's-complement form, excluding the sign bit.
// For non-negative values this is the bit length of the magnitude.
// A negative value whose magnitude is an exact power of two reports one bit
// fewer: magnitude 8 gives 4, but -8 gives 3.
int BigInteger::bitLength() const
{
    if (mag.empty()) {
        return 0;
    }
    const int wordCount = static_cast<int>(mag.size());
    const int bits = ((wordCount - 1) << 5) + bitLengthForInt(mag[0]);
    if (signum >= 0) {
        return bits;
    }
    bool powerOfTwo = (bitCount(mag[0]) == 1);
    for (size_t w = 1; powerOfTwo && w < mag.size(); ++w) {
        powerOfTwo = (mag[w] == 0);
    }
    return powerOfTwo ? bits - 1 : bits;
}
// Returns how many zero words sit at the low end of the magnitude, i.e. the
// index (counting from the least significant word) of the first non-zero word.
// The result is cached in firstNonzeroIntNumIndex offset by two; a stored 0
// triggers (re)computation.
int BigInteger::firstNonzeroIntNum()
{
    if (firstNonzeroIntNumIndex != 0) {
        return firstNonzeroIntNumIndex - 2;
    }
    const int wordCount = static_cast<int>(mag.size());
    int pos = wordCount - 1;
    while (pos >= 0 && mag[pos] == 0) {
        --pos;
    }
    const int zeroWords = wordCount - pos - 1;
    firstNonzeroIntNumIndex = zeroWords + 2;
    return zeroWords;
}
// Returns word `n` (0 = least significant) of the infinite two's-complement
// representation of this value.
// Returns 0 for negative n. Words past the magnitude are sign fill (-1 or 0).
// Negation is done in uint32_t: `-magInt` overflowed (UB) when the word was INT_MIN.
int BigInteger::getInt(int n)
{
    if (n < 0) {
        return 0;
    }
    if (n >= static_cast<int>(mag.size())) {
        return signum < 0 ? -1 : 0;
    }
    const int magInt = mag[mag.size() - static_cast<size_t>(n) - 1];
    if (signum >= 0) {
        return magInt;
    }
    if (n <= firstNonzeroIntNum()) {
        // Words up to and including the lowest non-zero one are negated...
        return static_cast<int>(~static_cast<uint32_t>(magInt) + 1u);
    }
    // ...higher words only need their bits inverted (the +1 carry stopped below).
    return ~magInt;
}
std::vector<unsigned char> BigInteger::toByteArray()
{
if (mag.empty()) {
return std::vector<unsigned char>();
}
int byteLen = (bitLength() / 8) + 1;
std::vector<unsigned char> byteArray(byteLen, 0x00);
int intIndex = 0;
int bytesCopied = 4;
int nextInt = 0;
for (int i = byteLen - 1; i >= 0; --i) {
if (bytesCopied == 4) {
nextInt = getInt(intIndex++);
bytesCopied = 1;
} else {
nextInt >>= 8;
bytesCopied++;
}
byteArray[i] = static_cast<unsigned char>(nextInt & 0xFF);
}
return byteArray;
}
void BigInteger::setByteArray(uint8_t *buffer, int capacity, int offset, int length)
{
if (buffer == nullptr || offset + length > capacity) {
throw std::runtime_error("BigInteger::setByteArray error");
return;
}
auto first = buffer + offset;
if (first[0] < 0) {
mag = makePositive(first, length);
signum = -1;
} else {
mag = stripLeadingZeroBytes(first, length);
signum = (mag.size() == 0 ? 0 : 1);
}
if (mag.size() >= MAX_MAG_LENGTH) {
checkRange();
}
return;
}
// Packs big-endian unsigned bytes a[0, byteLength) into big-endian 32-bit
// words, skipping leading zero bytes. Returns an empty vector when every byte
// is zero.
// Returns an empty vector for byteLength <= 0. Previously a negative length was
// compared as size_t, became huge, and caused out-of-bounds reads.
std::vector<int> BigInteger::stripLeadingZeroBytes(const uint8_t* a, const int byteLength)
{
    if (byteLength <= 0) {
        return std::vector<int>();
    }
    int keep = 0;
    while (keep < byteLength && a[keep] == 0) {
        ++keep;
    }
    const int significantBytes = byteLength - keep;
    if (significantBytes == 0) {
        return std::vector<int>();
    }
    const int intLength = (significantBytes + 3) / 4;
    std::vector<int> result(intLength, 0);
    int b = byteLength - 1;
    // Each word takes up to four bytes, least significant first.
    for (int i = intLength - 1; i >= 0; --i) {
        uint32_t word = 0;
        for (int taken = 0; taken < 4 && b >= keep; ++taken) {
            word |= static_cast<uint32_t>(a[b--]) << (8 * taken);
        }
        result[i] = static_cast<int>(word);
    }
    return result;
}
// Converts a negative big-endian two's-complement byte array a[0, byteLength)
// into the magnitude of its value, as big-endian 32-bit words.
// Leading sign-extension bytes (0xFF) are dropped. If every remaining byte is
// zero, the negation carries into one extra byte (e.g. {0xFF} -> {1},
// {0xFF, 0x00} -> {0x100}).
// Returns an empty vector for byteLength <= 0.
// The previous version:
//  - compared uint8_t with -1 (never true);
//  - never advanced its read position;
//  - used only three bytes per word;
//  - built the complement mask with the wrong shift direction.
std::vector<int> BigInteger::makePositive(const uint8_t* a, const int byteLength)
{
    if (byteLength <= 0) {
        return std::vector<int>();
    }
    int keep = 0;
    while (keep < byteLength && a[keep] == 0xFF) {
        ++keep;
    }
    int k = keep;
    while (k < byteLength && a[k] == 0) {
        ++k;
    }
    const int extraByte = (k == byteLength) ? 1 : 0;
    const int intLength = (byteLength - keep + extraByte + 3) / 4;

    // One's complement of the significant bytes, packed four per word.
    std::vector<uint32_t> words(static_cast<size_t>(intLength), 0);
    int b = byteLength - 1;
    for (int i = intLength - 1; i >= 0; --i) {
        uint32_t word = 0;
        int taken = 0;
        while (taken < 4 && b >= keep) {
            word |= static_cast<uint32_t>(a[b--]) << (8 * taken);
            ++taken;
        }
        // Only the bytes actually taken are complemented; the rest stay zero.
        const uint32_t mask = (taken == 4) ? 0xFFFFFFFFu : ((uint32_t{1} << (8 * taken)) - 1u);
        words[i] = ~word & mask;
    }
    // Add one to turn the one's complement into the two's complement.
    for (int i = intLength - 1; i >= 0; --i) {
        if (++words[i] != 0) {
            break;
        }
    }

    std::vector<int> result;
    result.reserve(words.size());
    for (uint32_t w : words) {
        result.push_back(static_cast<int>(w));
    }
    return result;
}
// Copy assignment: takes other's magnitude and sign.
// The cached firstNonzeroIntNumIndex is reset to 0 (firstNonzeroIntNum()'s
// "not computed" marker). The old cache described the previous magnitude, and
// a stale value made getInt()/toByteArray() wrong for negative numbers.
BigInteger& BigInteger::operator=(const BigInteger& other)
{
    if (this != &other) {
        this->mag = other.mag;
        this->signum = other.signum;
        this->firstNonzeroIntNumIndex = 0;
    }
    return *this;
}
// Two values are equal when both sign and magnitude words match exactly.
bool BigInteger::operator==(const BigInteger& other) const
{
    return signum == other.signum && mag == other.mag;
}
// Logical negation of operator==.
bool BigInteger::operator!=(const BigInteger& other) const
{
    return !(*this == other);
}
const BigInteger ZERO = BigInteger(static_cast<int64_t>(0));
const BigInteger ONE = BigInteger(static_cast<int64_t>(1));
const BigInteger TEN = BigInteger(static_cast<int64_t>(10));