#include "sql/recover_module/module.h"
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "base/check_op.h"
#include "base/strings/strcat.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_util.h"
#include "sql/recover_module/cursor.h"
#include "sql/recover_module/parsing.h"
#include "sql/recover_module/table.h"
#include "third_party/sqlite/sqlite3.h"
namespace sql {
namespace recover {
namespace {
// Indices into the |argv| array handed to xCreate/xConnect, as consumed by
// ModuleCreate(): the module name, the database name, the virtual table name,
// the backing table specification, and then one entry per column spec
// (from kFirstColumnArgument up to argc).
constexpr int kModuleNameArgument = 0;
constexpr int kVirtualTableDbNameArgument = kModuleNameArgument + 1;
constexpr int kVirtualTableNameArgument = kVirtualTableDbNameArgument + 1;
constexpr int kBackingTableSpecArgument = kVirtualTableNameArgument + 1;
constexpr int kFirstColumnArgument = kBackingTableSpecArgument + 1;
// Parses the column specifications in argv[kFirstColumnArgument, argc).
//
// Returns one RecoveredColumnSpec per column argument, in argument order.
// If any argument does not parse into a valid spec, returns an empty vector
// so the caller can reject the whole declaration. Also returns an empty
// vector if |argc| does not reach kFirstColumnArgument.
std::vector<RecoveredColumnSpec> ParseColumnSpecs(int argc,
                                                  const char* const* argv) {
  DCHECK_GE(argc, kFirstColumnArgument);

  std::vector<RecoveredColumnSpec> result;
  // Guard against a negative count being converted to a huge size_t in
  // reserve() when DCHECKs are compiled out.
  if (argc <= kFirstColumnArgument)
    return result;

  // Exactly one spec per column argument.
  result.reserve(static_cast<size_t>(argc - kFirstColumnArgument));
  for (int i = kFirstColumnArgument; i < argc; ++i) {
    result.emplace_back(ParseColumnSpec(argv[i]));
    if (!result.back().IsValid())
      return {};
  }
  return result;
}
// xCreate implementation for the "recover" virtual table module.
//
// Validates the arguments SQLite passes in |argv|, parses the backing table
// specification and the column specifications, builds the VirtualTable, and
// declares its schema via sqlite3_declare_vtab().
//
// Returns SQLITE_ERROR for malformed arguments. Otherwise it returns the status
// of the first failing step (table creation or schema declaration), or
// SQLITE_OK on success. On success, |*result_sqlite_table| points at the new
// table, and the VirtualTable is later deleted by ModuleDisconnect().
int ModuleCreate(sqlite3* sqlite_db,
                 void* /* client_data */,
                 int argc,
                 const char* const* argv,
                 sqlite3_vtab** result_sqlite_table,
                 char** /* error_string */) {
  DCHECK(sqlite_db != nullptr);
  // At least one column spec must follow the backing table spec.
  if (argc <= kFirstColumnArgument) {
    return SQLITE_ERROR;
  }
  DCHECK(argv != nullptr);
  DCHECK(result_sqlite_table != nullptr);
  DCHECK_EQ("recover", base::StringPiece(argv[kModuleNameArgument]));

  // Recovery virtual tables are only accepted in the "temp" database.
  base::StringPiece db_name(argv[kVirtualTableDbNameArgument]);
  if (db_name != "temp") {
    return SQLITE_ERROR;
  }

  base::StringPiece table_name(argv[kVirtualTableNameArgument]);
  if (!base::StartsWith(table_name, "recover_")) {
    return SQLITE_ERROR;
  }

  TargetTableSpec backing_table_spec =
      ParseTableSpec(argv[kBackingTableSpecArgument]);
  if (!backing_table_spec.IsValid()) {
    return SQLITE_ERROR;
  }

  std::vector<RecoveredColumnSpec> column_specs = ParseColumnSpecs(argc, argv);
  if (column_specs.empty()) {
    return SQLITE_ERROR;
  }

  auto [sqlite_status, table] = VirtualTable::Create(
      sqlite_db, std::move(backing_table_spec), std::move(column_specs));
  if (sqlite_status != SQLITE_OK)
    return sqlite_status;

  {
    std::string create_table_sql = table->ToCreateTableSql();
    const int declare_status =
        sqlite3_declare_vtab(sqlite_db, create_table_sql.c_str());
    // On failure, |table| has not been released yet, so it frees the
    // VirtualTable when this function returns. SQLite never receives it.
    if (declare_status != SQLITE_OK)
      return declare_status;
  }

  *result_sqlite_table = table->SqliteTable();
  // Ownership passes to SQLite. ModuleDisconnect() deletes the table.
  table.release();
  return SQLITE_OK;
}
// xConnect implementation. Connecting to a recover table is handled exactly
// like creating one, so every argument is forwarded unchanged to
// ModuleCreate() and its status is returned.
int ModuleConnect(sqlite3* sqlite_db,
                  void* client_data,
                  int argc,
                  const char* const* argv,
                  sqlite3_vtab** result_sqlite_table,
                  char** error_string) {
  const int create_status = ModuleCreate(sqlite_db, client_data, argc, argv,
                                         result_sqlite_table, error_string);
  return create_status;
}
// xBestIndex implementation. Advertises no index usage at all:
// - usable constraints are marked as not consumed (argvIndex 0, omit false);
// - ORDER BY is not consumed;
// - no idxNum/idxStr is produced (ModuleFilter() ignores them anyway).
int ModuleBestIndex(sqlite3_vtab* sqlite_table,
                    sqlite3_index_info* index_info) {
  DCHECK(sqlite_table != nullptr);
  DCHECK(index_info != nullptr);

  const int constraint_count = index_info->nConstraint;
  for (int constraint_index = 0; constraint_index < constraint_count;
       ++constraint_index) {
    if (!index_info->aConstraint[constraint_index].usable)
      continue;
    auto& usage = index_info->aConstraintUsage[constraint_index];
    usage.argvIndex = 0;
    usage.omit = false;
  }

  index_info->idxNum = 0;
  index_info->idxStr = nullptr;
  index_info->needToFreeIdxStr = 0;
  index_info->orderByConsumed = 0;
  return SQLITE_OK;
}
// xDisconnect implementation: deletes the VirtualTable that backs
// |sqlite_table| (handed to SQLite by ModuleCreate()).
int ModuleDisconnect(sqlite3_vtab* sqlite_table) {
  DCHECK(sqlite_table != nullptr);
  delete VirtualTable::FromSqliteTable(sqlite_table);
  return SQLITE_OK;
}
// xDestroy implementation: handled identically to xDisconnect.
int ModuleDestroy(sqlite3_vtab* sqlite_table) {
  const int disconnect_status = ModuleDisconnect(sqlite_table);
  return disconnect_status;
}
// xOpen implementation: asks the VirtualTable behind |sqlite_table| for a new
// cursor and reports its SQLite-facing handle through |result_sqlite_cursor|.
// The cursor is deleted later by ModuleClose().
int ModuleOpen(sqlite3_vtab* sqlite_table,
               sqlite3_vtab_cursor** result_sqlite_cursor) {
  DCHECK(sqlite_table != nullptr);
  DCHECK(result_sqlite_cursor != nullptr);
  VirtualCursor* const new_cursor =
      VirtualTable::FromSqliteTable(sqlite_table)->CreateCursor();
  *result_sqlite_cursor = new_cursor->SqliteCursor();
  return SQLITE_OK;
}
// xClose implementation: deletes the VirtualCursor behind |sqlite_cursor|.
int ModuleClose(sqlite3_vtab_cursor* sqlite_cursor) {
  DCHECK(sqlite_cursor != nullptr);
  delete VirtualCursor::FromSqliteCursor(sqlite_cursor);
  return SQLITE_OK;
}
// xFilter implementation: ignores every index/constraint argument (see
// ModuleBestIndex()) and rewinds the cursor via VirtualCursor::First(),
// returning its status.
int ModuleFilter(sqlite3_vtab_cursor* sqlite_cursor,
                 int /* idx_num */,
                 const char* /* idx_str */,
                 int /* argc */,
                 sqlite3_value** /* argv */) {
  DCHECK(sqlite_cursor != nullptr);
  return VirtualCursor::FromSqliteCursor(sqlite_cursor)->First();
}
// xNext implementation: advances the cursor and returns the status of
// VirtualCursor::Next().
int ModuleNext(sqlite3_vtab_cursor* sqlite_cursor) {
  DCHECK(sqlite_cursor != nullptr);
  return VirtualCursor::FromSqliteCursor(sqlite_cursor)->Next();
}
// xEof implementation: returns 0 while the cursor is valid and 1 once it is
// not.
int ModuleEof(sqlite3_vtab_cursor* sqlite_cursor) {
  DCHECK(sqlite_cursor != nullptr);
  VirtualCursor* const recover_cursor =
      VirtualCursor::FromSqliteCursor(sqlite_cursor);
  if (recover_cursor->IsValid())
    return 0;
  return 1;
}
// xColumn implementation: delegates to VirtualCursor::ReadColumn() for column
// |column_index|, passing |result_context| through, and returns its status.
// Requires the cursor to be positioned on a valid row.
int ModuleColumn(sqlite3_vtab_cursor* sqlite_cursor,
                 sqlite3_context* result_context,
                 int column_index) {
  DCHECK(sqlite_cursor != nullptr);
  DCHECK(result_context != nullptr);
  VirtualCursor* const cursor = VirtualCursor::FromSqliteCursor(sqlite_cursor);
  // This handler serves xColumn, so the diagnostic must name xColumn.
  DCHECK(cursor->IsValid())
      << "SQLite called xColumn() without a valid cursor";
  return cursor->ReadColumn(column_index, result_context);
}
// xRowid implementation: stores the current row's id (from
// VirtualCursor::RowId()) into |result_rowid|. Requires the cursor to be
// positioned on a valid row.
int ModuleRowid(sqlite3_vtab_cursor* sqlite_cursor,
                sqlite3_int64* result_rowid) {
  DCHECK(sqlite_cursor != nullptr);
  DCHECK(result_rowid != nullptr);
  VirtualCursor* const recover_cursor =
      VirtualCursor::FromSqliteCursor(sqlite_cursor);
  DCHECK(recover_cursor->IsValid())
      << "SQLite called xRowid() without a valid cursor";
  *result_rowid = recover_cursor->RowId();
  return SQLITE_OK;
}
// Value stored in sqlite3_module::iVersion for this module.
constexpr int kSqliteModuleApiVersion = 1;

// Method table for the "recover" module, in sqlite3_module field order.
// xUpdate, the transaction hooks, xFindFunction and xRename are null. Any
// fields after xRename are not listed and are zero-initialized by aggregate
// initialization.
constexpr sqlite3_module kSqliteModule = {
    kSqliteModuleApiVersion,  // iVersion
    ModuleCreate,             // xCreate
    ModuleConnect,            // xConnect
    ModuleBestIndex,          // xBestIndex
    ModuleDisconnect,         // xDisconnect
    ModuleDestroy,            // xDestroy
    ModuleOpen,               // xOpen
    ModuleClose,              // xClose
    ModuleFilter,             // xFilter
    ModuleNext,               // xNext
    ModuleEof,                // xEof
    ModuleColumn,             // xColumn
    ModuleRowid,              // xRowid
    nullptr,                  // xUpdate
    nullptr,                  // xBegin
    nullptr,                  // xSync
    nullptr,                  // xCommit
    nullptr,                  // xRollback
    nullptr,                  // xFindFunction
    nullptr,                  // xRename
};
}
// Registers the "recover" virtual table module (kSqliteModule) on |db| and
// returns the status from sqlite3_create_module_v2().
int RegisterRecoverExtension(sqlite3* db) {
  // The module carries no client data, so no client-data destructor is given.
  void* const client_data = nullptr;
  return sqlite3_create_module_v2(db, "recover", &kSqliteModule, client_data,
                                  /*xDestroy=*/nullptr);
}
}
}