/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef OMNI_RUNTIME_BIT_UTIL_H
#define OMNI_RUNTIME_BIT_UTIL_H

#include <stdint.h>
#include <math.h>

namespace omniruntime {
class BitUtil {
public:
    constexpr static uint8_t kZeroBitmasks[] = {
        static_cast<uint8_t>(~(1 << 0)),
        static_cast<uint8_t>(~(1 << 1)),
        static_cast<uint8_t>(~(1 << 2)),
        static_cast<uint8_t>(~(1 << 3)),
        static_cast<uint8_t>(~(1 << 4)),
        static_cast<uint8_t>(~(1 << 5)),
        static_cast<uint8_t>(~(1 << 6)),
        static_cast<uint8_t>(~(1 << 7)),
    };

    template <typename T, typename U> constexpr static inline T RoundUp(T value, U factor)
    {
        return (value + (factor - 1)) / factor * factor;
    }

    constexpr static inline int32_t Nbytes(int32_t bits)
    {
        return RoundUp(bits, 8) / 8;
    }

    constexpr static inline uint64_t Nwords(int32_t bits)
    {
        return RoundUp(bits, 64) / 64;
    }

    constexpr static inline uint64_t LowMask(int32_t bits)
    {
        return (1UL << bits) - 1;
    }

    constexpr static inline uint64_t HighMask(int32_t bits)
    {
        return LowMask(bits) << (64 - bits);
    }

    template <typename T> static inline bool IsBitSet(const T *bits, int32_t idx)
    {
        return bits[idx / (sizeof(bits[0]) * 8)] & (static_cast<T>(1) << (idx & ((sizeof(bits[0]) * 8) - 1)));
    }

    template <typename T> static inline void SetBit(T *bits, uint32_t idx)
    {
        auto bitsAs8Bit = reinterpret_cast<uint8_t *>(bits);
        bitsAs8Bit[idx / 8] |= (1 << (idx % 8));
    }

    template <typename T> static inline void ClearBit(T *bits, uint32_t idx)
    {
        auto bitsAs8Bit = reinterpret_cast<uint8_t *>(bits);
        bitsAs8Bit[idx / 8] &= kZeroBitmasks[idx % 8];
    }

    template <typename T> static inline void SetBit(T *bits, uint32_t idx, bool value)
    {
        value ? SetBit(bits, idx) : ClearBit(bits, idx);
    }

    template <typename PartialWordFunc, typename FullWordFunc>
    static inline void ForEachWord(int32_t begin, int32_t end, PartialWordFunc partialWordFunc,
        FullWordFunc fullWordFunc)
    {
        if (begin >= end) {
            return;
        }
        int32_t firstWord = RoundUp(begin, 64);
        int32_t lastWord = end & ~63L;
        if (lastWord < firstWord) {
            partialWordFunc(lastWord / 64, LowMask(end - lastWord) & HighMask(firstWord - begin));
            return;
        }
        if (begin != firstWord) {
            partialWordFunc(begin / 64, HighMask(firstWord - begin));
        }
        for (int32_t i = firstWord; i + 64 <= lastWord; i += 64) {
            fullWordFunc(i / 64);
        }
        if (end != lastWord) {
            partialWordFunc(lastWord / 64, LowMask(end - lastWord));
        }
    }

    template <typename PartialWordFunc, typename FullWordFunc>
    static inline bool testWords(int32_t begin, int32_t end, PartialWordFunc partialWordFunc, FullWordFunc fullWordFunc)
    {
        if (begin >= end) {
            return true;
        }
        int32_t firstWord = RoundUp(begin, 64);
        int32_t lastWord = end & ~63L;
        if (lastWord < firstWord) {
            return partialWordFunc(lastWord / 64, LowMask(end - lastWord) & HighMask(firstWord - begin));
        }
        if (begin != firstWord) {
            if (!partialWordFunc(begin / 64, HighMask(firstWord - begin))) {
                return false;
            }
        }
        for (int32_t i = firstWord; i + 64 <= lastWord; i += 64) {
            if (!fullWordFunc(i / 64)) {
                return false;
            }
        }
        if (end != lastWord) {
            return partialWordFunc(lastWord / 64, LowMask(end - lastWord));
        }
        return true;
    }

    static inline void FillBits(uint64_t *bits, int32_t begin, int32_t end, bool value)
    {
        ForEachWord(
            begin, end,
            [bits, value](int32_t idx, uint64_t mask) {
                if (value) {
                    bits[idx] |= static_cast<uint64_t>(-1) & mask;
                } else {
                    bits[idx] &= ~mask;
                }
            },
            [bits, value](int32_t idx) { bits[idx] = value ? -1 : 0; });
    }

    static inline int32_t CountBits(const uint64_t *bits, int32_t begin, int32_t end)
    {
        int32_t count = 0;
        ForEachWord(
            begin, end, [&count, bits](int32_t idx, uint64_t mask) { count += __builtin_popcountll(bits[idx] & mask); },
            [&count, bits](int32_t idx) { count += __builtin_popcountll(bits[idx]); });
        return count;
    }

    static inline bool HasBitSet(const uint64_t *bits, int32_t begin, int32_t end)
    {
        return !testWords(
            begin, end,
            [bits](int32_t idx, uint64_t mask) {
                uint64_t word = bits[idx] & mask;
                return !word;
            },
            [bits](int32_t idx) {
                uint64_t word = bits[idx];
                return !word;
            });
    }

    template <typename T> static inline T LoadBits(const uint64_t *source, uint64_t bitOffset, uint8_t numBits)
    {
        constexpr int32_t kBitSize = 8 * sizeof(T);
        auto address = reinterpret_cast<uint64_t>(source) + bitOffset / 8;
        T word = *reinterpret_cast<const T *>(address);
        auto bit = bitOffset & 7;
        if (!bit) {
            return word;
        }
        if (numBits + bit <= kBitSize) {
            return word >> bit;
        }
        uint8_t lastByte = reinterpret_cast<const uint8_t *>(address)[sizeof(T)];
        uint64_t lastBits = static_cast<T>(lastByte) << (kBitSize - bit);
        return (word >> bit) | lastBits;
    }

    template <typename T>
    static inline void StoreBits(uint64_t *target, uint64_t offset, uint64_t word, uint8_t numBits)
    {
        constexpr int32_t kBitSize = 8 * sizeof(T);
        T *address = reinterpret_cast<T *>(reinterpret_cast<uint64_t>(target) + (offset / 8));
        auto bitOffset = offset & 7;
        uint64_t mask = (numBits == 64 ? ~0UL : ((1UL << numBits) - 1)) << bitOffset;
        *address = (*address & ~mask) | (mask & (word << bitOffset));
        if (numBits + bitOffset > kBitSize) {
            uint8_t *lastByteAddress = reinterpret_cast<uint8_t *>(address) + sizeof(T);
            uint8_t lastByteBits = bitOffset + numBits - kBitSize;
            uint8_t lastByteMask = (1 << lastByteBits) - 1;
            *lastByteAddress = (*lastByteAddress & ~lastByteMask) | (lastByteMask & (word >> (kBitSize - bitOffset)));
        }
    }

    static inline void CopyBits(const uint64_t *source, uint64_t sourceOffset, uint64_t *target, uint64_t targetOffset,
        uint64_t numBits)
    {
        uint64_t i = 0;
        for (; i + 64 <= numBits; i += 64) {
            uint64_t word = LoadBits<uint64_t>(source, i + sourceOffset, 64);
            StoreBits<uint64_t>(target, targetOffset + i, word, 64);
        }
        if (i + 32 <= numBits) {
            auto lastWord = LoadBits<uint32_t>(source, sourceOffset + i, 32);
            StoreBits<uint32_t>(target, targetOffset + i, lastWord, 32);
            i += 32;
        }
        if (i + 16 <= numBits) {
            auto lastWord = LoadBits<uint16_t>(source, sourceOffset + i, 16);
            StoreBits<uint16_t>(target, targetOffset + i, lastWord, 16);
            i += 16;
        }
        for (; i < numBits; i += 8) {
            auto copyBits = std::min<uint64_t>(numBits - i, 8);
            auto lastWord = LoadBits<uint8_t>(source, sourceOffset + i, copyBits);
            StoreBits<uint8_t>(target, targetOffset + i, lastWord, copyBits);
        }
    }
};
}

#endif // OMNI_RUNTIME_BIT_UTIL_H