#include "sql/statement.h"
#include <stddef.h>
#include <stdint.h>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "base/check.h"
#include "base/check_op.h"
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/dcheck_is_on.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/numerics/safe_conversions.h"
#include "base/sequence_checker.h"
#include "base/strings/string_util.h"
#include "base/strings/string_view_util.h"
#include "base/strings/utf_string_conversions.h"
#include "base/threading/scoped_blocking_call.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "base/trace_event/trace_event.h"
#include "sql/database.h"
#include "sql/sqlite_result_code.h"
#include "sql/sqlite_result_code_values.h"
#include "third_party/sqlite/sqlite3.h"
namespace sql {
// Converts |time| into the integer form used for SQLite columns: the number of
// microseconds elapsed since the Windows epoch.
int64_t Statement::TimeToSqlValue(base::Time time) {
  const base::TimeDelta since_windows_epoch = time.ToDeltaSinceWindowsEpoch();
  return since_windows_epoch.InMicroseconds();
}
// Returns the SQL text of |stmt| for use as a trace-event argument.
//
// When SQLite is built with tracing support, sqlite3_expanded_sql() is used,
// which substitutes the currently bound parameter values into the text.
// Returns an empty string when SQLite cannot produce any text (for example a
// null |stmt|, an allocation failure, or an expansion exceeding
// SQLITE_LIMIT_LENGTH).
std::string GetSqlStatementStringForTracing(sqlite3_stmt* stmt) {
#if defined(SQLITE_OMIT_TRACE)
  // sqlite3_sql() returns memory owned by the statement; it must not be freed.
  // It returns null for a null statement handle.
  const char* sql = sqlite3_sql(stmt);
  return sql ? std::string(sql) : std::string();
#else
  // sqlite3_expanded_sql() returns a string obtained from sqlite3_malloc(); the
  // caller owns it and must release it with sqlite3_free(), otherwise every
  // traced statement leaks its expanded text. It may also return null.
  char* expanded_sql = sqlite3_expanded_sql(stmt);
  if (!expanded_sql) {
    return std::string();
  }
  std::string result(expanded_sql);
  sqlite3_free(expanded_sql);
  return result;
#endif
}
// Default constructor. |ref_| is populated with a placeholder StatementRef built
// from two null pointers and `false`, so |ref_| itself is never null and callers
// can rely on the ref's own validity reporting instead of null checks.
Statement::Statement()
    : ref_(base::MakeRefCounted<Database::StatementRef>(nullptr, nullptr,
                                                        false)) {}
Statement::Statement(scoped_refptr<Database::StatementRef> ref)
: ref_(std::move(ref)) {}
// Destructor: resets the statement and clears its bound variables before the
// StatementRef reference is dropped.
Statement::~Statement() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  constexpr bool kClearBoundVars = true;
  Reset(kClearBoundVars);
}
// Replaces the underlying StatementRef with |ref|.
void Statement::Assign(scoped_refptr<Database::StatementRef> ref) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Finish with the current statement first: Reset() reports metrics for a
  // valid statement and clears its bound variables.
  constexpr bool kClearBoundVars = true;
  Reset(kClearBoundVars);
  ref_ = std::move(ref);
}
// Detaches this Statement from its current StatementRef by swapping in an empty
// placeholder ref, and forgets the result of any previous step.
void Statement::Clear() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto placeholder_ref =
      base::MakeRefCounted<Database::StatementRef>(nullptr, nullptr, false);
  Assign(std::move(placeholder_ref));
  last_sqlite_result_code_ = std::nullopt;
}
// Returns whether the statement is currently valid. Debug builds crash if the
// underlying ref reports was_valid() == false, since mutating such a statement
// is a caller bug.
bool Statement::CheckValid() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DLOG_IF(FATAL, !ref_->was_valid())
      << "Cannot call mutating statements on an invalid statement.";
  const bool currently_valid = is_valid();
  return currently_valid;
}
// Crashes (CHECK, also in release builds) unless column |column_index| can be
// read right now: the statement must be valid, the most recent step must have
// produced a row, and the index must be within the row's column count.
void Statement::CheckCanReadColumn(int column_index) const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK(is_valid());
  // Column values only exist after a step that returned a row.
  CHECK(last_sqlite_result_code_.has_value());
  CHECK_EQ(*last_sqlite_result_code_, SqliteResultCode::kRow);
  // Bounds check against the number of columns in the current row.
  CHECK_GE(column_index, 0);
  CHECK_LT(column_index, sqlite3_data_count(ref_->stmt()));
}
// Performs one sqlite3_step() call and records timing metrics for it.
// Returns kError without stepping if the statement is invalid; otherwise
// returns the step's result after passing it through CheckSqliteResultCode().
SqliteResultCode Statement::StepInternal() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!CheckValid()) {
    return SqliteResultCode::kError;
  }

  base::ElapsedTimer step_timer;

  // The first step since construction or the last Reset() starts the stepping
  // time accumulator and opens a trace slice for this statement.
  // ReportQueryExecutionMetrics() closes the slice.
  const bool is_first_step = !time_spent_stepping_.has_value();
  if (is_first_step) {
    time_spent_stepping_ = base::TimeDelta();
    TRACE_EVENT_BEGIN("sql", "Database::Statement",
                      ref_->database()->GetTracingNamedTrack(),
                      step_timer.start_time(), "statement",
                      GetSqlStatementStringForTracing(ref_->stmt()));
  }

  std::optional<base::ScopedBlockingCall> blocking_call;
  ref_->InitScopedBlockingCall(FROM_HERE, &blocking_call);

  const int raw_step_result = sqlite3_step(ref_->stmt());
  const SqliteResultCode step_result = ToSqliteResultCode(raw_step_result);

  const base::TimeDelta step_duration = step_timer.Elapsed();
  ref_->database()->RecordTimingHistogram("Sql.Statement.StepTime.",
                                          step_duration);
  *time_spent_stepping_ += step_duration;

  return CheckSqliteResultCode(step_result);
}
void Statement::ReportQueryExecutionMetrics() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
const int kResetVMStepsToZero = 1;
const int vm_steps = sqlite3_stmt_status(
ref_->stmt(), SQLITE_STMTSTATUS_VM_STEP, kResetVMStepsToZero);
const Database* database = ref_->database();
if (vm_steps > 0 && !database->histogram_tag().empty()) {
const std::string histogram_name =
"Sql.Statement." + database->histogram_tag() + ".VMSteps";
base::UmaHistogramCounts10000(histogram_name, vm_steps);
}
if (time_spent_stepping_) {
TRACE_EVENT_END("sql", database->GetTracingNamedTrack(), "statement",
GetSqlStatementStringForTracing(ref_->stmt()));
database->RecordTimingHistogram("Sql.Statement.ExecutionTime.",
*time_spent_stepping_);
}
}
bool Statement::Run() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
DCHECK(!run_called_) << "Run() must be called exactly once";
run_called_ = true;
DCHECK(!step_called_) << "Run() must not be mixed with Step()";
#endif
return StepInternal() == SqliteResultCode::kDone;
}
bool Statement::Step() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
DCHECK(!run_called_) << "Run() must not be mixed with Step()";
step_called_ = true;
#endif
return StepInternal() == SqliteResultCode::kRow;
}
// Prepares the statement for another execution. For a valid statement this
// reports the finished execution's metrics and resets the underlying ref,
// optionally clearing bound variables. Per-execution bookkeeping is always
// cleared.
void Statement::Reset(bool clear_bound_vars) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  std::optional<base::ScopedBlockingCall> blocking_call;
  ref_->InitScopedBlockingCall(FROM_HERE, &blocking_call);

  if (is_valid()) {
    // Metrics are read before the ref is reset.
    ReportQueryExecutionMetrics();
    ref_->Reset(clear_bound_vars);
  }

  if (auto* db = ref_->database()) {
    db->ReleaseCacheMemoryIfNeeded(false);
  }

  // Forget everything tied to the previous execution.
  last_sqlite_result_code_ = std::nullopt;
  time_spent_stepping_ = std::nullopt;
#if DCHECK_IS_ON()
  run_called_ = false;
  step_called_ = false;
#endif
}
// Returns true if the statement is valid and its most recent step produced a
// SQLite success code. Returns false if it has not been stepped since the last
// reset.
bool Statement::Succeeded() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!is_valid() || !last_sqlite_result_code_.has_value()) {
    return false;
  }
  return IsSqliteSuccessCode(*last_sqlite_result_code_);
}
// Common preamble for binders. Debug builds verify that the zero-based
// |param_index| is below the statement's parameter count. The call then asks
// the ref to clear blob memory recorded for that index (presumably memory kept
// alive for an earlier BindBlob() on the same parameter; verify in
// Database::StatementRef).
void Statement::WillBindParameter(int param_index) {
  DCHECK_GE(param_index, 0);
  DCHECK_LT(param_index, sqlite3_bind_parameter_count(ref_->stmt()))
      << "Invalid parameter index";
  ref_->ClearBlobMemory(param_index);
}
// Binds SQL NULL to the zero-based parameter |param_index|. No-op on an
// invalid statement.
void Statement::BindNull(int param_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
  DCHECK(!run_called_) << __func__ << " must not be called after Run()";
  DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
  if (!is_valid()) {
    return;
  }
  WillBindParameter(param_index);
  // SQLite numbers parameters from 1.
  const int sqlite_param_index = param_index + 1;
  const int rc = sqlite3_bind_null(ref_->stmt(), sqlite_param_index);
  DCHECK_EQ(rc, SQLITE_OK);
}
// Binds |val| as the integer 1 (true) or 0 (false).
void Statement::BindBool(int param_index, bool val) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const int64_t as_integer = val ? 1 : 0;
  BindInt64(param_index, as_integer);
}
// Binds a 32-bit integer to the zero-based parameter |param_index|. No-op on
// an invalid statement.
void Statement::BindInt(int param_index, int val) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
  DCHECK(!run_called_) << __func__ << " must not be called after Run()";
  DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
  if (!is_valid()) {
    return;
  }
  WillBindParameter(param_index);
  const int sqlite_param_index = param_index + 1;  // SQLite is 1-based.
  const int rc = sqlite3_bind_int(ref_->stmt(), sqlite_param_index, val);
  DCHECK_EQ(rc, SQLITE_OK);
}
// Binds a 64-bit integer to the zero-based parameter |param_index|. No-op on
// an invalid statement.
void Statement::BindInt64(int param_index, int64_t val) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
  DCHECK(!run_called_) << __func__ << " must not be called after Run()";
  DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
  if (!is_valid()) {
    return;
  }
  WillBindParameter(param_index);
  const int sqlite_param_index = param_index + 1;  // SQLite is 1-based.
  const int rc = sqlite3_bind_int64(ref_->stmt(), sqlite_param_index, val);
  DCHECK_EQ(rc, SQLITE_OK);
}
// Binds a floating-point value to the zero-based parameter |param_index|.
// No-op on an invalid statement.
void Statement::BindDouble(int param_index, double val) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
  DCHECK(!run_called_) << __func__ << " must not be called after Run()";
  DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
  if (!is_valid()) {
    return;
  }
  WillBindParameter(param_index);
  const int sqlite_param_index = param_index + 1;  // SQLite is 1-based.
  const int rc = sqlite3_bind_double(ref_->stmt(), sqlite_param_index, val);
  DCHECK_EQ(rc, SQLITE_OK);
}
// Binds |val| as an integer count of microseconds since the Windows epoch
// (see TimeToSqlValue()). No-op on an invalid statement.
void Statement::BindTime(int param_index, base::Time val) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
  DCHECK(!run_called_) << __func__ << " must not be called after Run()";
  DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
  if (!is_valid()) {
    return;
  }
  WillBindParameter(param_index);
  const int64_t microseconds_since_epoch = TimeToSqlValue(val);
  const int sqlite_param_index = param_index + 1;  // SQLite is 1-based.
  const int rc = sqlite3_bind_int64(ref_->stmt(), sqlite_param_index,
                                    microseconds_since_epoch);
  DCHECK_EQ(rc, SQLITE_OK);
}
// Binds |delta| as an integer count of microseconds. No-op on an invalid
// statement.
void Statement::BindTimeDelta(int param_index, base::TimeDelta delta) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
  DCHECK(!run_called_) << __func__ << " must not be called after Run()";
  DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
  if (!is_valid()) {
    return;
  }
  WillBindParameter(param_index);
  const int64_t microseconds = delta.InMicroseconds();
  const int sqlite_param_index = param_index + 1;  // SQLite is 1-based.
  const int rc =
      sqlite3_bind_int64(ref_->stmt(), sqlite_param_index, microseconds);
  DCHECK_EQ(rc, SQLITE_OK);
}
// Binds the NUL-terminated string |val| as TEXT. |val| must not be null.
// No-op on an invalid statement.
void Statement::BindCString(int param_index, const char* val) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
  DCHECK(!run_called_) << __func__ << " must not be called after Run()";
  DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
  DCHECK(val);
  if (!is_valid()) {
    return;
  }
  WillBindParameter(param_index);
  // A negative length tells SQLite to read up to the NUL terminator.
  // SQLITE_TRANSIENT makes SQLite copy the text before returning.
  constexpr int kReadToNulTerminator = -1;
  const int sqlite_param_index = param_index + 1;  // SQLite is 1-based.
  const int rc = sqlite3_bind_text(ref_->stmt(), sqlite_param_index, val,
                                   kReadToNulTerminator, SQLITE_TRANSIENT);
  DCHECK_EQ(rc, SQLITE_OK);
}
void Statement::BindString(int param_index, std::string_view value) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
DCHECK(!run_called_) << __func__ << " must not be called after Run()";
DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
if (!is_valid())
return;
WillBindParameter(param_index);
static constexpr char kEmptyPlaceholder[] = {0x00};
const char* data = (value.size() > 0) ? value.data() : kEmptyPlaceholder;
int sqlite_result_code = sqlite3_bind_text(
ref_->stmt(), param_index + 1, data, value.size(), SQLITE_TRANSIENT);
DCHECK_EQ(sqlite_result_code, SQLITE_OK);
}
// Binds |value| as TEXT after converting it from UTF-16 to UTF-8.
void Statement::BindString16(int param_index, std::u16string_view value) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const std::string utf8_value = base::UTF16ToUTF8(value);
  BindString(param_index, utf8_value);
}
void Statement::BindBlob(int param_index,
scoped_refptr<base::RefCountedMemory> blob) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
#if DCHECK_IS_ON()
DCHECK(!run_called_) << __func__ << " must not be called after Run()";
DCHECK(!step_called_) << __func__ << " must not be called after Step()";
#endif
if (!is_valid()) {
return;
}
WillBindParameter(param_index);
base::span<const uint8_t> value =
ref_->TakeBlobMemory(param_index, std::move(blob));
static constexpr uint8_t kEmptyPlaceholder[] = {0x00};
const uint8_t* data = (value.size() > 0) ? value.data() : kEmptyPlaceholder;
int sqlite_result_code = sqlite3_bind_blob(ref_->stmt(), param_index + 1,
data, value.size(), SQLITE_STATIC);
DCHECK_EQ(sqlite_result_code, SQLITE_OK);
}
// Binds the bytes of |blob| as a BLOB. The string is moved into ref-counted
// storage so the bound bytes can be kept alive without copying.
void Statement::BindBlob(int param_index, std::string blob) {
  auto shared_bytes =
      base::MakeRefCounted<base::RefCountedString>(std::move(blob));
  BindBlob(param_index, std::move(shared_bytes));
}
// Binds the raw code-unit bytes of |blob| as a BLOB. The string is moved into
// ref-counted storage.
void Statement::BindBlob(int param_index, std::u16string blob) {
  auto shared_bytes =
      base::MakeRefCounted<base::RefCountedString16>(std::move(blob));
  BindBlob(param_index, std::move(shared_bytes));
}
// Binds the contents of |blob| as a BLOB. The vector is moved into ref-counted
// storage.
void Statement::BindBlob(int param_index, std::vector<uint8_t> blob) {
  auto shared_bytes =
      base::MakeRefCounted<base::RefCountedBytes>(std::move(blob));
  BindBlob(param_index, std::move(shared_bytes));
}
// Binds a copy of the bytes in |blob| as a BLOB. The bytes are copied into
// ref-counted storage because the span does not own them.
void Statement::BindBlob(int param_index, base::span<const uint8_t> blob) {
  auto copied_bytes = base::MakeRefCounted<base::RefCountedBytes>(blob);
  BindBlob(param_index, std::move(copied_bytes));
}
void Statement::BindBlobForStreaming(int param_index, uint64_t size) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!is_valid()) {
return;
}
CHECK_EQ(SQLITE_OK, sqlite3_bind_zeroblob(ref_->stmt(), param_index + 1,
base::checked_cast<int>(size)));
}
// Returns the number of columns in the statement's result set, or 0 for an
// invalid statement.
int Statement::ColumnCount() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return is_valid() ? sqlite3_column_count(ref_->stmt()) : 0;
}
static_assert(static_cast<int>(ColumnType::kInteger) == SQLITE_INTEGER,
"INTEGER mismatch");
static_assert(static_cast<int>(ColumnType::kFloat) == SQLITE_FLOAT,
"FLOAT mismatch");
static_assert(static_cast<int>(ColumnType::kText) == SQLITE_TEXT,
"TEXT mismatch");
static_assert(static_cast<int>(ColumnType::kBlob) == SQLITE_BLOB,
"BLOB mismatch");
static_assert(static_cast<int>(ColumnType::kNull) == SQLITE_NULL,
"NULL mismatch");
// Returns the SQLite storage class of column |col| in the current row.
// CHECK-fails if the column cannot be read.
ColumnType Statement::GetColumnType(int col) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CheckCanReadColumn(col);
  const int sqlite_type = sqlite3_column_type(ref_->stmt(), col);
  return static_cast<enum ColumnType>(sqlite_type);
}
// Reads column |column_index| as an integer and returns whether it is non-zero.
bool Statement::ColumnBool(int column_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return ColumnInt64(column_index) != 0;
}
// Returns column |column_index| of the current row as a 32-bit integer.
// CHECK-fails if the column cannot be read.
int Statement::ColumnInt(int column_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CheckCanReadColumn(column_index);
  const int value = sqlite3_column_int(ref_->stmt(), column_index);
  return value;
}
// Returns column |column_index| of the current row as a 64-bit integer.
// CHECK-fails if the column cannot be read.
int64_t Statement::ColumnInt64(int column_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CheckCanReadColumn(column_index);
  const int64_t value = sqlite3_column_int64(ref_->stmt(), column_index);
  return value;
}
// Returns column |column_index| of the current row as a double.
// CHECK-fails if the column cannot be read.
double Statement::ColumnDouble(int column_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CheckCanReadColumn(column_index);
  const double value = sqlite3_column_double(ref_->stmt(), column_index);
  return value;
}
// Reads column |column_index| as microseconds since the Windows epoch (the
// inverse of TimeToSqlValue()) and returns the corresponding base::Time.
base::Time Statement::ColumnTime(int column_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CheckCanReadColumn(column_index);
  const base::TimeDelta since_windows_epoch =
      base::Microseconds(sqlite3_column_int64(ref_->stmt(), column_index));
  return base::Time::FromDeltaSinceWindowsEpoch(since_windows_epoch);
}
// Reads column |column_index| as a count of microseconds and returns it as a
// TimeDelta.
base::TimeDelta Statement::ColumnTimeDelta(int column_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CheckCanReadColumn(column_index);
  const int64_t microseconds =
      sqlite3_column_int64(ref_->stmt(), column_index);
  const base::TimeDelta delta = base::Microseconds(microseconds);
  return delta;
}
std::string_view Statement::ColumnStringView(int column_index) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
CheckCanReadColumn(column_index);
const char* string_buffer = reinterpret_cast<const char*>(
sqlite3_column_text(ref_->stmt(), column_index));
int size = sqlite3_column_bytes(ref_->stmt(), column_index);
DCHECK(size == 0 || string_buffer != nullptr)
<< "sqlite3_column_text() returned a null buffer for a non-empty string";
return std::string_view(string_buffer, base::checked_cast<size_t>(size));
}
// Returns an owned copy of column |column_index| as UTF-8 text.
std::string Statement::ColumnString(int column_index) {
  const std::string_view text = ColumnStringView(column_index);
  return std::string(text);
}
// Returns column |column_index| converted from UTF-8 text to UTF-16.
std::u16string Statement::ColumnString16(int column_index) {
  const std::string_view utf8_text = ColumnStringView(column_index);
  return base::UTF8ToUTF16(utf8_text);
}
// Returns column |column_index| of the current row as raw bytes, without
// copying. The span points into memory managed by SQLite and shares the
// lifetime rules of sqlite3_column_blob().
base::span<const uint8_t> Statement::ColumnBlob(int column_index) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CheckCanReadColumn(column_index);
  sqlite3_stmt* const stmt = ref_->stmt();
  // Fetch the pointer before its size, following SQLite's recommended order.
  const auto* bytes =
      static_cast<const uint8_t*>(sqlite3_column_blob(stmt, column_index));
  const int byte_count = sqlite3_column_bytes(stmt, column_index);
  DCHECK(byte_count == 0 || bytes != nullptr)
      << "sqlite3_column_blob() returned a null buffer for a non-empty BLOB";
  return UNSAFE_TODO(
      base::span(bytes, base::checked_cast<size_t>(byte_count)));
}
std::string Statement::ColumnBlobAsString(int column_index) {
return std::string(base::as_string_view(ColumnBlob(column_index)));
}
std::optional<std::u16string> Statement::ColumnBlobAsString16(
int column_index) {
base::span<const uint8_t> bytes = ColumnBlob(column_index);
if (bytes.size() % 2 != 0) {
return std::nullopt;
}
std::u16string result(bytes.size() / 2, 0);
base::as_writable_byte_span(result).copy_from_nonoverlapping(bytes);
return result;
}
std::vector<uint8_t> Statement::ColumnBlobAsVector(int column_index) {
base::span<const uint8_t> byte_span = ColumnBlob(column_index);
return std::vector<uint8_t>(byte_span.begin(), byte_span.end());
}
// Returns the original SQL text the statement was prepared from, or an empty
// string when there is no underlying sqlite3_stmt.
std::string Statement::GetSQLStatement() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // sqlite3_sql() returns null for a null statement handle, for example when
  // the StatementRef holds no sqlite3_stmt. Constructing a std::string from a
  // null pointer is undefined behavior, so map that case to "".
  const char* sql = sqlite3_sql(ref_->stmt());
  return sql ? std::string(sql) : std::string();
}
// Records |sqlite_result_code| as the most recent result. For a non-success
// code, forwards the error to the owning Database if there is one. Always
// returns |sqlite_result_code| unchanged.
SqliteResultCode Statement::CheckSqliteResultCode(
    SqliteResultCode sqlite_result_code) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  last_sqlite_result_code_ = sqlite_result_code;

  if (IsSqliteSuccessCode(sqlite_result_code)) {
    return sqlite_result_code;
  }
  auto* database = ref_.get() ? ref_->database() : nullptr;
  if (!database) {
    return sqlite_result_code;
  }
  database->OnSqliteError(ToSqliteErrorCode(sqlite_result_code), this,
                          nullptr);
  return sqlite_result_code;
}
}