* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
* Description: common string utility operations
*/
#ifndef OMNI_RUNTIME_STRING_UTIL_H
#define OMNI_RUNTIME_STRING_UTIL_H
#include <algorithm>
#include <codecvt>
#include <cstdint>
#include <iostream>
#include <locale>
#include <memory>
#include <string>
#include <string_view>
#include <libboundscheck/include/securec.h>
#include "util/utf8_util.h"
namespace omniruntime::codegen::function {
// Error messages used by the string functions of this namespace.
static std::string REPLACE_ERR_MSG("Replace failed");
static std::string CONCAT_ERR_MSG("Concat failed");
static std::string CONCAT_WS_ERR_MSG("ConcatWs failed");
// Zero-length result; the literal still holds its terminating NUL byte.
static constexpr uint8_t EMPTY[] = "";
// Offset between an ASCII lower-case letter and its upper-case counterpart ('a' - 'A' == 32).
static int32_t STEP = 'a' - 'A';
// Length in bytes of the UTF-8 sequence introduced by each possible lead byte (index = byte value).
// A 0 entry marks a byte that cannot start a sequence.
static uint8_t BytesOfCodePointInUTF8[] = {
    // 0x00 - 0x7F: single-byte (ASCII) characters
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    // 0x80 - 0xBF: continuation bytes
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // 0xC0 - 0xC1: not valid in well-formed UTF-8
    0, 0,
    // 0xC2 - 0xDF: lead bytes of two-byte sequences
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    // 0xE0 - 0xEF: lead bytes of three-byte sequences
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
    // 0xF0 - 0xF4: lead bytes of four-byte sequences
    4, 4, 4, 4, 4,
    // 0xF5 - 0xFF: not valid in well-formed UTF-8
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
};
class StringUtil {
public:
    /**
     * Decodes a UTF-8 encoded string into a wide string.
     * @param s UTF-8 bytes. They are only read, so the parameter is a const reference; every caller of the
     *          former non-const-reference signature still compiles.
     * @return the decoded wide string
     * @throws std::range_error (from std::wstring_convert::from_bytes) when s is not valid UTF-8
     * NOTE(review): std::wstring_convert and std::codecvt_utf8 are deprecated since C++17.
     */
    static inline std::wstring ToWideString(const std::string &s)
    {
        std::wstring_convert<std::codecvt_utf8<wchar_t>> convert;
        return convert.from_bytes(s);
    }

    /**
     * Finds the prefix of a UTF-8 string holding at most min(srcWidth, dstWidth) code points.
     * Nothing is copied: the input pointer is returned and only the length is adjusted.
     * @param hasErr set to true when Utf8Util::LengthOfCodePoint returns a negative length for a lead byte
     * @param str source bytes
     * @param srcWidth width of the source type, in code points
     * @param strLen number of bytes in str
     * @param outLen receives the byte length of the prefix (0 on error); never larger than strLen
     * @param dstWidth width of the target type, in code points
     * @return str, or nullptr on error
     */
    static inline const char *CastStrStr(bool *hasErr, const char *str, int32_t srcWidth, int32_t strLen,
        int32_t *outLen, int32_t dstWidth)
    {
        const int32_t maxCodePoints = std::min(srcWidth, dstWidth);
        int32_t byteLen = 0;
        int32_t codePoints = 0;
        while (byteLen < strLen && codePoints < maxCodePoints) {
            const int32_t charLen = omniruntime::Utf8Util::LengthOfCodePoint(str[byteLen]);
            if (charLen < 0) {
                return ReportFailure(hasErr, outLen);
            }
            byteLen += charLen;
            codePoints++;
        }
        // A code point truncated by the end of the buffer would otherwise push the length past strLen,
        // making callers read beyond the input.
        *outLen = std::min(byteLen, strLen);
        return str;
    }

    /**
     * Concatenates a CHAR(aWidth) value with another string.
     * When bp is non-empty, ap is first right-padded with spaces up to aWidth code points.
     * When bp is empty, the result is ap without padding.
     * The result is written into a buffer of *outLen + 1 bytes from ArenaAllocatorMalloc(contextPtr, ...).
     * @return the concatenation, EMPTY when it has no bytes, or nullptr (with *hasErr set) on failure
     */
    static inline const char *ConcatCharDiffWidths(int64_t contextPtr, const char *ap, int32_t aWidth, int32_t apLen,
        const char *bp, int32_t bpLen, bool *hasErr, int32_t *outLen)
    {
        int32_t aPaddingCount = 0;
        if (bpLen > 0) {
            const int32_t codePoints = omniruntime::Utf8Util::CountCodePoints(ap, apLen);
            const int32_t missing = aWidth - codePoints;
            // Never negative: a value with more code points than aWidth is concatenated without padding,
            // instead of handing memset_s a negative (huge unsigned) count.
            aPaddingCount = missing > 0 ? missing : 0;
        }
        if (!NarrowLength(static_cast<int64_t>(apLen) + aPaddingCount + bpLen, hasErr, outLen)) {
            return nullptr;
        }
        if (*outLen <= 0) {
            *outLen = 0;
            return reinterpret_cast<const char *>(EMPTY);
        }
        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
        errno_t res1 = memcpy_s(ret, *outLen + 1, ap, apLen);
        errno_t res2 = memset_s(ret + apLen, *outLen - apLen + 1, ' ', aPaddingCount);
        errno_t res3 = memcpy_s(ret + apLen + aPaddingCount, *outLen - (apLen + aPaddingCount) + 1, bp, bpLen);
        if (res1 != EOK || res2 != EOK || res3 != EOK) {
            return ReportFailure(hasErr, outLen);
        }
        return ret;
    }

    /**
     * Concatenates ap and bp into a buffer of *outLen + 1 bytes from ArenaAllocatorMalloc(contextPtr, ...).
     * @return the concatenation, EMPTY when it has no bytes, or nullptr (with *hasErr set) on failure
     */
    static inline const char *ConcatStrDiffWidths(int64_t contextPtr, const char *ap, int32_t apLen, const char *bp,
        int32_t bpLen, bool *hasErr, int32_t *outLen)
    {
        if (!NarrowLength(static_cast<int64_t>(apLen) + bpLen, hasErr, outLen)) {
            return nullptr;
        }
        if (*outLen <= 0) {
            *outLen = 0;
            return reinterpret_cast<const char *>(EMPTY);
        }
        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
        errno_t res1 = memcpy_s(ret, *outLen + 1, ap, apLen);
        errno_t res2 = memcpy_s(ret + apLen, *outLen - apLen + 1, bp, bpLen);
        if (res1 != EOK || res2 != EOK) {
            return ReportFailure(hasErr, outLen);
        }
        return ret;
    }

    /**
     * Writes ap, separator and bp, in that order, into a buffer of *outLen + 1 bytes from
     * ArenaAllocatorMalloc(contextPtr, ...).
     * @return the joined bytes, EMPTY when they total zero bytes, or nullptr (with *hasErr set) on failure
     */
    static inline const char *ConcatWsStrDiffWidths(int64_t contextPtr, const char *separator, int32_t separatorLen,
        const char *ap, int32_t apLen, const char *bp, int32_t bpLen, bool *hasErr, int32_t *outLen)
    {
        if (!NarrowLength(static_cast<int64_t>(apLen) + separatorLen + bpLen, hasErr, outLen)) {
            return nullptr;
        }
        if (*outLen <= 0) {
            *outLen = 0;
            return reinterpret_cast<const char *>(EMPTY);
        }
        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
        errno_t res1 = memcpy_s(ret, *outLen + 1, ap, apLen);
        errno_t res2 = memcpy_s(ret + apLen, *outLen + 1 - apLen, separator, separatorLen);
        errno_t res3 = memcpy_s(ret + apLen + separatorLen, *outLen + 1 - apLen - separatorLen, bp, bpLen);
        if (res1 != EOK || res2 != EOK || res3 != EOK) {
            return ReportFailure(hasErr, outLen);
        }
        return ret;
    }

    /**
     * Folds one more value into a running CONCAT_WS result held in tmp/tmpLen.
     * The first value (tmpSet == false) is adopted as-is, with no copy and no separator.
     * Each later value is appended after `separator` through ConcatWsStrDiffWidths, and tmp/tmpLen then
     * point at the new buffer.
     * @return false when *hasErr is set after the concatenation, true otherwise
     */
    static inline bool ConcatWsAppend(int64_t contextPtr, const char *separator, int32_t separatorLen,
        const char *&tmp, int32_t &tmpLen, bool &tmpSet, const char *s, int32_t sLen, bool *hasErr, int32_t *outLen)
    {
        if (!tmpSet) {
            tmp = s;
            tmpLen = sLen;
            tmpSet = true;
            *outLen = tmpLen;
            return true;
        }
        const char *next = ConcatWsStrDiffWidths(contextPtr, separator, separatorLen, tmp, tmpLen, s, sLen, hasErr, outLen);
        if (*hasErr) {
            return false;
        }
        tmp = next;
        tmpLen = *outLen;
        return true;
    }

    /**
     * Replaces every non-overlapping occurrence of searchStr in str with replaceStr.
     * The scan goes left to right; text produced by a replacement is never searched again.
     * An empty str yields EMPTY. An empty searchStr is delegated to ReplaceWithSearchEmpty; scanning for
     * it here would never advance.
     * The result is copied into a buffer of *outLen + 1 bytes from ArenaAllocatorMalloc(contextPtr, ...).
     * @return the replaced string, or nullptr (with *hasErr set) on failure
     */
    static inline const char *ReplaceWithSearchNotEmpty(int64_t contextPtr, const char *str, int32_t strLen,
        const char *searchStr, int32_t searchLen, const char *replaceStr, int32_t replaceLen, bool *hasErr,
        int32_t *outLen)
    {
        if (strLen == 0) {
            *outLen = 0;
            return reinterpret_cast<const char *>(EMPTY);
        }
        if (searchLen <= 0) {
            return ReplaceWithSearchEmpty(contextPtr, str, strLen, replaceStr, replaceLen, hasErr, outLen);
        }
        const std::string_view source(str, static_cast<std::string_view::size_type>(strLen));
        const std::string_view search(searchStr, static_cast<std::string_view::size_type>(searchLen));
        // Build the result in one pass rather than editing a copy of the input in place for every match.
        std::string result;
        result.reserve(source.size());
        std::string_view::size_type copied = 0;
        for (auto match = source.find(search); match != std::string_view::npos;
             match = source.find(search, copied)) {
            result.append(source.data() + copied, match - copied);
            result.append(replaceStr, static_cast<std::string::size_type>(replaceLen));
            copied = match + search.size();
        }
        result.append(source.data() + copied, source.size() - copied);

        if (!NarrowLength(static_cast<int64_t>(result.size()), hasErr, outLen)) {
            return nullptr;
        }
        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
        errno_t res = memcpy_s(ret, *outLen + 1, result.data(), result.size());
        if (res != EOK) {
            return ReportFailure(hasErr, outLen);
        }
        return ret;
    }

    /**
     * Replace with an empty search string: inserts replaceStr before every code point of str and once more
     * at the end. For example, "ab" with "-" yields "-a-b-".
     * A lead byte for which Utf8Util::LengthOfCodePoint returns a non-positive length is reported as an error.
     * The result is written into a buffer of *outLen + 1 bytes from ArenaAllocatorMalloc(contextPtr, ...).
     * @return the expanded string, or nullptr (with *hasErr set) on failure
     */
    static inline const char *ReplaceWithSearchEmpty(int64_t contextPtr, const char *str, int32_t strLen,
        const char *replaceStr, int32_t replaceLen, bool *hasErr, int32_t *outLen)
    {
        // First pass: count code points with the same stepping the copy loop uses, so the allocation always
        // matches what is written. Invalid lead bytes are rejected before anything is allocated.
        int64_t codePoints = 0;
        for (int32_t index = 0; index < strLen; ++codePoints) {
            const int32_t step = CodePointLengthAt(str, index, strLen);
            if (step <= 0) {
                return ReportFailure(hasErr, outLen);
            }
            index += step;
        }
        if (!NarrowLength(static_cast<int64_t>(strLen) + (codePoints + 1) * replaceLen, hasErr, outLen)) {
            return nullptr;
        }
        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
        int32_t indexBuffer = 0;
        for (int32_t index = 0; index < strLen;) {
            if (memcpy_s(ret + indexBuffer, *outLen - indexBuffer + 1, replaceStr, replaceLen) != EOK) {
                return ReportFailure(hasErr, outLen);
            }
            indexBuffer += replaceLen;
            // Positive here (validated by the first pass) and clamped to the bytes that remain in str.
            const int32_t codePointLength = CodePointLengthAt(str, index, strLen);
            if (memcpy_s(ret + indexBuffer, *outLen - indexBuffer + 1, str + index, codePointLength) != EOK) {
                return ReportFailure(hasErr, outLen);
            }
            indexBuffer += codePointLength;
            index += codePointLength;
        }
        if (memcpy_s(ret + indexBuffer, *outLen - indexBuffer + 1, replaceStr, replaceLen) != EOK) {
            return ReportFailure(hasErr, outLen);
        }
        return ret;
    }

    /**
     * Removes leading and trailing ' ' characters from str in place.
     * Other whitespace, such as tabs and newlines, is kept.
     */
    static inline void TrimString(std::string &str)
    {
        str.erase(0, str.find_first_not_of(' '));
        str.erase(str.find_last_not_of(' ') + 1);
    }

    /**
     * Returns whether matchStr[0, matchLen) occurs in srcStr[0, srcLen), using Knuth-Morris-Pratt matching.
     * An empty pattern is contained in every string.
     */
    static inline bool StrContainsStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen)
    {
        // Must precede building the table: with matchLen == 0 the write to next[0] would be out of bounds.
        if (matchLen <= 0) {
            return true;
        }
        if (srcLen < matchLen) {
            return false;
        }
        // next[k] is the pattern index to resume from after a mismatch at index k; -1 means advance the source.
        std::unique_ptr<int32_t[]> next = std::make_unique<int32_t[]>(matchLen);
        next[0] = -1;
        int32_t i = 0;
        int32_t j = -1;
        while (i < matchLen - 1) {
            if (j == -1 || matchStr[i] == matchStr[j]) {
                i++;
                j++;
                next[i] = j;
            } else {
                j = next[j];
            }
        }
        i = 0;
        j = 0;
        while (i < srcLen && j < matchLen) {
            if (j == -1 || srcStr[i] == matchStr[j]) {
                i++;
                j++;
            } else {
                j = next[j];
            }
        }
        return j == matchLen;
    }

    /**
     * Counts characters in str[0, strLen), stepping by the sequence length in BytesOfCodePointInUTF8.
     * A byte whose table entry is 0 (it cannot start a sequence) counts as a one-byte character.
     */
    static inline int32_t NumChars(const char *str, int32_t strLen)
    {
        int32_t len = 0;
        int32_t i = 0;
        while (i < strLen) {
            len += 1;
            int32_t offset = str[i] & 0xFF;
            uint8_t numBytes = BytesOfCodePointInUTF8[offset];
            i += numBytes == 0 ? 1 : numBytes;
        }
        return len;
    }

private:
    // Marks the operation as failed and clears the output length; returns nullptr so callers can return it.
    static inline const char *ReportFailure(bool *hasErr, int32_t *outLen)
    {
        *hasErr = true;
        *outLen = 0;
        return nullptr;
    }

    // Stores a byte count computed in 64 bits into *outLen. Returns false (with *hasErr set) when the count,
    // plus the extra byte every allocation here requests, would not fit in int32_t.
    static inline bool NarrowLength(int64_t total, bool *hasErr, int32_t *outLen)
    {
        if (total >= INT32_MAX) {
            ReportFailure(hasErr, outLen);
            return false;
        }
        *outLen = static_cast<int32_t>(total);
        return true;
    }

    // Length of the code point starting at str[index], as reported by Utf8Util::LengthOfCodePoint.
    // Positive lengths are clamped to the bytes left in the buffer. Non-positive results are passed through
    // so callers can reject them.
    static inline int32_t CodePointLengthAt(const char *str, int32_t index, int32_t strLen)
    {
        const int32_t len = omniruntime::Utf8Util::LengthOfCodePoint(str[index]);
        return len <= 0 ? len : std::min(len, strLen - index);
    }
};
}
#endif