sqlx-record/src/value.rs

589 lines
17 KiB
Rust

use sqlx::query::{Query, QueryAs, QueryScalar};
use sqlx::types::chrono::{NaiveDate, NaiveDateTime};
use crate::filter::placeholder;
// Backend type aliases. Exactly one of the `mysql`, `postgres`, `sqlite` features is
// expected to be enabled; each group below defines the same three names (`DB`,
// `Arguments`, `Arguments_`) for its backend. Enabling two backends would define those
// names twice and fail with confusing duplicate-definition errors, so fail early with a
// clear message instead.
#[cfg(any(
    all(feature = "mysql", feature = "postgres"),
    all(feature = "mysql", feature = "sqlite"),
    all(feature = "postgres", feature = "sqlite"),
))]
compile_error!(
    "only one database backend feature may be enabled: choose one of `mysql`, `postgres`, or `sqlite`"
);

// --- MySQL ---
/// The sqlx database driver selected by the enabled backend feature.
#[cfg(feature = "mysql")]
pub type DB = sqlx::MySql;
/// Query arguments type for the selected backend (no borrowed lifetime).
#[cfg(feature = "mysql")]
pub type Arguments = sqlx::mysql::MySqlArguments;
/// Lifetime-parameterized arguments type; the `'q` parameter is unused on MySQL.
#[cfg(feature = "mysql")]
pub type Arguments_<'q> = sqlx::mysql::MySqlArguments;

// --- PostgreSQL ---
/// The sqlx database driver selected by the enabled backend feature.
#[cfg(feature = "postgres")]
pub type DB = sqlx::Postgres;
/// Query arguments type for the selected backend (no borrowed lifetime).
#[cfg(feature = "postgres")]
pub type Arguments = sqlx::postgres::PgArguments;
/// Lifetime-parameterized arguments type; the `'q` parameter is unused on PostgreSQL.
#[cfg(feature = "postgres")]
pub type Arguments_<'q> = sqlx::postgres::PgArguments;

// --- SQLite ---
/// The sqlx database driver selected by the enabled backend feature.
#[cfg(feature = "sqlite")]
pub type DB = sqlx::Sqlite;
/// Query arguments for contexts that cannot carry a lifetime; fixed to `'static` on SQLite.
#[cfg(feature = "sqlite")]
pub type Arguments = sqlx::sqlite::SqliteArguments<'static>;
/// Lifetime-aware arguments type: SQLite arguments borrow for `'q`.
#[cfg(feature = "sqlite")]
pub type Arguments_<'q> = sqlx::sqlite::SqliteArguments<'q>;
/// A dynamically typed SQL parameter that can be bound onto a sqlx query.
///
/// Unsigned variants are bound natively on MySQL. The PostgreSQL/SQLite bind paths in
/// this module convert them to signed integer types before binding.
#[derive(Debug, Clone)]
pub enum Value {
    /// Signed 8-bit integer.
    Int8(i8),
    /// Unsigned 8-bit integer.
    Uint8(u8),
    /// Signed 16-bit integer.
    Int16(i16),
    /// Unsigned 16-bit integer.
    Uint16(u16),
    /// Signed 32-bit integer.
    Int32(i32),
    /// Unsigned 32-bit integer.
    Uint32(u32),
    /// Signed 64-bit integer.
    Int64(i64),
    /// Unsigned 64-bit integer.
    Uint64(u64),
    /// Raw byte buffer.
    VecU8(Vec<u8>),
    /// Owned text.
    String(String),
    /// Boolean flag.
    Bool(bool),
    /// UUID value.
    Uuid(uuid::Uuid),
    /// Calendar date without a time zone (chrono `NaiveDate`).
    NaiveDate(NaiveDate),
    /// Date and time without a time zone (chrono `NaiveDateTime`).
    NaiveDateTime(NaiveDateTime),
}
/// Right-hand side of a `column = ...` assignment in an UPDATE statement.
///
/// Covers more than plain value assignment; this is what the `eval_*` methods on
/// UpdateForm accept. Every variant renders relative to the target column.
#[derive(Debug, Clone)]
pub enum UpdateExpr {
    /// `column = value` — the same result as the `with_*` methods.
    Set(Value),
    /// `column = column + value`
    Add(Value),
    /// `column = column - value`
    Sub(Value),
    /// `column = column * value`
    Mul(Value),
    /// `column = column / value`
    Div(Value),
    /// `column = column % value`
    Mod(Value),
    /// `column = CASE WHEN cond1 THEN val1 WHEN cond2 THEN val2 ... ELSE default END`
    Case {
        /// `(condition, value)` pairs, one per WHEN branch, in order.
        branches: Vec<(crate::filter::Filter<'static>, Value)>,
        /// Value used by the ELSE branch.
        default: Value,
    },
    /// `column = CASE WHEN condition THEN column + value ELSE column END`
    AddIf {
        condition: crate::filter::Filter<'static>,
        value: Value,
    },
    /// `column = CASE WHEN condition THEN column - value ELSE column END`
    SubIf {
        condition: crate::filter::Filter<'static>,
        value: Value,
    },
    /// `column = COALESCE(column, value)`
    Coalesce(Value),
    /// `column = GREATEST(column, value)` — MySQL/PostgreSQL only.
    Greatest(Value),
    /// `column = LEAST(column, value)` — MySQL/PostgreSQL only.
    Least(Value),
    /// Escape hatch: `column = {sql}` with caller-supplied SQL.
    ///
    /// Write placeholders in `sql` as `?`; they are rewritten to the backend's
    /// placeholder syntax and bound from `values` in order.
    Raw {
        sql: String,
        values: Vec<Value>,
    },
}
impl UpdateExpr {
/// Build SQL expression and collect values for binding.
/// Returns (sql_fragment, values_to_bind)
pub fn build_sql(&self, column: &str, start_idx: usize) -> (String, Vec<Value>) {
use crate::filter::Filter;
match self {
UpdateExpr::Set(v) => {
(placeholder(start_idx), vec![v.clone()])
}
UpdateExpr::Add(v) => {
(format!("{} + {}", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Sub(v) => {
(format!("{} - {}", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Mul(v) => {
(format!("{} * {}", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Div(v) => {
(format!("{} / {}", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Mod(v) => {
(format!("{} % {}", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Case { branches, default } => {
let mut sql_parts = vec!["CASE".to_string()];
let mut values = Vec::new();
let mut idx = start_idx;
for (condition, value) in branches {
let (cond_sql, cond_values) = Filter::build_where_clause_with_offset(
&[condition.clone()],
idx,
);
idx += cond_values.len();
values.extend(cond_values);
sql_parts.push(format!("WHEN {} THEN {}", cond_sql, placeholder(idx)));
values.push(value.clone());
idx += 1;
}
sql_parts.push(format!("ELSE {} END", placeholder(idx)));
values.push(default.clone());
(sql_parts.join(" "), values)
}
UpdateExpr::AddIf { condition, value } => {
let (cond_sql, cond_values) = Filter::build_where_clause_with_offset(
&[condition.clone()],
start_idx,
);
let mut values = cond_values;
let val_idx = start_idx + values.len();
let sql = format!(
"CASE WHEN {} THEN {} + {} ELSE {} END",
cond_sql,
column,
placeholder(val_idx),
column
);
values.push(value.clone());
(sql, values)
}
UpdateExpr::SubIf { condition, value } => {
let (cond_sql, cond_values) = Filter::build_where_clause_with_offset(
&[condition.clone()],
start_idx,
);
let mut values = cond_values;
let val_idx = start_idx + values.len();
let sql = format!(
"CASE WHEN {} THEN {} - {} ELSE {} END",
cond_sql,
column,
placeholder(val_idx),
column
);
values.push(value.clone());
(sql, values)
}
UpdateExpr::Coalesce(v) => {
(format!("COALESCE({}, {})", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Greatest(v) => {
(format!("GREATEST({}, {})", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Least(v) => {
(format!("LEAST({}, {})", column, placeholder(start_idx)), vec![v.clone()])
}
UpdateExpr::Raw { sql, values } => {
// Replace ? placeholders with proper database placeholders
let mut result_sql = String::new();
let mut placeholder_count = 0;
for ch in sql.chars() {
if ch == '?' {
result_sql.push_str(&placeholder(start_idx + placeholder_count));
placeholder_count += 1;
} else {
result_sql.push(ch);
}
}
(result_sql, values.clone())
}
}
}
/// Returns the number of bind parameters this expression will use
pub fn param_count(&self) -> usize {
match self {
UpdateExpr::Set(_) => 1,
UpdateExpr::Add(_) => 1,
UpdateExpr::Sub(_) => 1,
UpdateExpr::Mul(_) => 1,
UpdateExpr::Div(_) => 1,
UpdateExpr::Mod(_) => 1,
UpdateExpr::Case { branches, default: _ } => {
branches.iter().map(|(f, _)| f.param_count() + 1).sum::<usize>() + 1
}
UpdateExpr::AddIf { condition, value: _ } => condition.param_count() + 1,
UpdateExpr::SubIf { condition, value: _ } => condition.param_count() + 1,
UpdateExpr::Coalesce(_) => 1,
UpdateExpr::Greatest(_) => 1,
UpdateExpr::Least(_) => 1,
UpdateExpr::Raw { sql: _, values } => values.len(),
}
}
}
/// Former name of [`Value`], kept as an alias for older code.
#[deprecated(note = "Please use Value instead", since = "0.1.0")]
pub type SqlValue = Value;
// MySQL backend: unsigned integers have native column types, so every `Value`
// variant is bound as-is. `$item` is expected to be a `&Value`; each arm binds a
// reference to the inner data and evaluates to the updated query.
#[cfg(feature = "mysql")]
macro_rules! bind_value {
    ($target:expr, $item:expr) => {
        match $item {
            Value::Int8(x) => $target.bind(x),
            Value::Uint8(x) => $target.bind(x),
            Value::Int16(x) => $target.bind(x),
            Value::Uint16(x) => $target.bind(x),
            Value::Int32(x) => $target.bind(x),
            Value::Uint32(x) => $target.bind(x),
            Value::Int64(x) => $target.bind(x),
            Value::Uint64(x) => $target.bind(x),
            Value::VecU8(x) => $target.bind(x),
            Value::String(x) => $target.bind(x),
            Value::Bool(x) => $target.bind(x),
            Value::Uuid(x) => $target.bind(x),
            Value::NaiveDate(x) => $target.bind(x),
            Value::NaiveDateTime(x) => $target.bind(x),
        }
    };
}
// PostgreSQL / SQLite backends: no unsigned integer types, so unsigned variants are
// bound as signed ones.
// - u8 -> i16, u16 -> i32 and u32 -> i64 are lossless widenings.
// - u64 -> i64 uses `as`, which keeps the bit pattern, so values above `i64::MAX`
//   are stored as negative numbers.
// `$item` is expected to be a `&Value`, and the macro evaluates to the updated query.
#[cfg(any(feature = "postgres", feature = "sqlite"))]
macro_rules! bind_value {
    ($target:expr, $item:expr) => {
        match $item {
            Value::Int8(x) => $target.bind(x),
            Value::Uint8(x) => $target.bind(i16::from(*x)),
            Value::Int16(x) => $target.bind(x),
            Value::Uint16(x) => $target.bind(i32::from(*x)),
            Value::Int32(x) => $target.bind(x),
            Value::Uint32(x) => $target.bind(i64::from(*x)),
            Value::Int64(x) => $target.bind(x),
            Value::Uint64(x) => $target.bind(*x as i64),
            Value::VecU8(x) => $target.bind(x),
            Value::String(x) => $target.bind(x),
            Value::Bool(x) => $target.bind(x),
            Value::Uuid(x) => $target.bind(x),
            Value::NaiveDate(x) => $target.bind(x),
            Value::NaiveDateTime(x) => $target.bind(x),
        }
    };
}
/// Bind every element of `values`, in slice order, onto an untyped [`Query`].
///
/// Values are borrowed for `'q`, so the slice must outlive the query.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
pub fn bind_values<'q>(query: Query<'q, DB, Arguments_<'q>>, values: &'q [Value]) -> Query<'q, DB, Arguments_<'q>> {
    values
        .iter()
        .fold(query, |bound, item| bind_value!(bound, item))
}
/// Bind one owned [`Value`] to `query`, moving its data into the query.
///
/// On PostgreSQL and SQLite, unsigned integers are bound as signed ones: u8 -> i16,
/// u16 -> i32, u32 -> i64, and u64 -> i64 via `as`. The last cast means values above
/// `i64::MAX` wrap to negative numbers.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
pub fn bind_value_owned<'q>(query: Query<'q, DB, Arguments_<'q>>, value: Value) -> Query<'q, DB, Arguments_<'q>> {
    match value {
        Value::Int8(n) => query.bind(n),
        #[cfg(feature = "mysql")]
        Value::Uint8(n) => query.bind(n),
        #[cfg(any(feature = "postgres", feature = "sqlite"))]
        Value::Uint8(n) => query.bind(n as i16),
        Value::Int16(n) => query.bind(n),
        #[cfg(feature = "mysql")]
        Value::Uint16(n) => query.bind(n),
        #[cfg(any(feature = "postgres", feature = "sqlite"))]
        Value::Uint16(n) => query.bind(n as i32),
        Value::Int32(n) => query.bind(n),
        #[cfg(feature = "mysql")]
        Value::Uint32(n) => query.bind(n),
        #[cfg(any(feature = "postgres", feature = "sqlite"))]
        Value::Uint32(n) => query.bind(n as i64),
        Value::Int64(n) => query.bind(n),
        #[cfg(feature = "mysql")]
        Value::Uint64(n) => query.bind(n),
        #[cfg(any(feature = "postgres", feature = "sqlite"))]
        Value::Uint64(n) => query.bind(n as i64),
        Value::VecU8(bytes) => query.bind(bytes),
        Value::String(text) => query.bind(text),
        Value::Bool(flag) => query.bind(flag),
        Value::Uuid(id) => query.bind(id),
        Value::NaiveDate(date) => query.bind(date),
        Value::NaiveDateTime(stamp) => query.bind(stamp),
    }
}
/// Bind every element of `values`, in slice order, onto a typed [`QueryAs`].
///
/// Values are borrowed for `'q`, so the slice must outlive the query.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
pub fn bind_as_values<'q, O>(query: QueryAs<'q, DB, O, Arguments_<'q>>, values: &'q [Value]) -> QueryAs<'q, DB, O, Arguments_<'q>> {
    // `.iter()` rather than `.into_iter()` on a slice reference (clippy::into_iter_on_ref);
    // both yield `&Value`, but `.iter()` states that intent directly.
    values.iter().fold(query, |query, value| bind_value!(query, value))
}
/// Bind every element of `values`, in slice order, onto a [`QueryScalar`].
///
/// Values are borrowed for `'q`, so the slice must outlive the query.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
pub fn bind_scalar_values<'q, O>(query: QueryScalar<'q, DB, O, Arguments_<'q>>, values: &'q [Value]) -> QueryScalar<'q, DB, O, Arguments_<'q>> {
    values
        .iter()
        .fold(query, |bound, item| bind_value!(bound, item))
}
/// Build a comma-separated column list from field definitions.
///
/// Each entry contributes only its first whitespace-separated token, so a definition
/// such as `"id BIGINT"` yields `id`.
///
/// Leading whitespace and tab separators are tolerated. Entries that are empty or
/// whitespace-only are skipped, rather than producing an empty column name such as
/// `"a, , b"`.
#[inline]
pub fn query_fields(fields: Vec<&str>) -> String {
    fields
        .iter()
        .filter_map(|field| field.split_whitespace().next())
        .collect::<Vec<_>>()
        .join(", ")
}
// From implementations for owned values
impl From<String> for Value {
fn from(value: String) -> Self {
Value::String(value)
}
}
impl From<i8> for Value {
fn from(value: i8) -> Self {
Value::Int8(value)
}
}
impl From<u8> for Value {
fn from(value: u8) -> Self {
Value::Uint8(value)
}
}
impl From<i16> for Value {
fn from(value: i16) -> Self {
Value::Int16(value)
}
}
impl From<u16> for Value {
fn from(value: u16) -> Self {
Value::Uint16(value)
}
}
impl From<i32> for Value {
fn from(value: i32) -> Self {
Value::Int32(value)
}
}
impl From<u32> for Value {
fn from(value: u32) -> Self {
Value::Uint32(value)
}
}
impl From<i64> for Value {
fn from(value: i64) -> Self {
Value::Int64(value)
}
}
impl From<u64> for Value {
fn from(value: u64) -> Self {
Value::Uint64(value)
}
}
impl From<bool> for Value {
fn from(value: bool) -> Self {
Value::Bool(value)
}
}
impl From<uuid::Uuid> for Value {
fn from(value: uuid::Uuid) -> Self {
Value::Uuid(value)
}
}
impl From<Vec<u8>> for Value {
fn from(value: Vec<u8>) -> Self {
Value::VecU8(value)
}
}
impl From<NaiveDate> for Value {
fn from(value: NaiveDate) -> Self {
Value::NaiveDate(value)
}
}
impl From<NaiveDateTime> for Value {
fn from(value: NaiveDateTime) -> Self {
Value::NaiveDateTime(value)
}
}
// From implementations for references
impl From<&str> for Value {
fn from(value: &str) -> Self {
Value::String(value.to_string())
}
}
impl From<&String> for Value {
fn from(value: &String) -> Self {
Value::String(value.clone())
}
}
impl From<&i8> for Value {
fn from(value: &i8) -> Self {
Value::Int8(*value)
}
}
impl From<&u8> for Value {
fn from(value: &u8) -> Self {
Value::Uint8(*value)
}
}
impl From<&i16> for Value {
fn from(value: &i16) -> Self {
Value::Int16(*value)
}
}
impl From<&u16> for Value {
fn from(value: &u16) -> Self {
Value::Uint16(*value)
}
}
impl From<&i32> for Value {
fn from(value: &i32) -> Self {
Value::Int32(*value)
}
}
impl From<&u32> for Value {
fn from(value: &u32) -> Self {
Value::Uint32(*value)
}
}
impl From<&i64> for Value {
fn from(value: &i64) -> Self {
Value::Int64(*value)
}
}
impl From<&u64> for Value {
fn from(value: &u64) -> Self {
Value::Uint64(*value)
}
}
impl From<&bool> for Value {
fn from(value: &bool) -> Self {
Value::Bool(*value)
}
}
impl From<&uuid::Uuid> for Value {
fn from(value: &uuid::Uuid) -> Self {
Value::Uuid(*value)
}
}
impl From<&NaiveDate> for Value {
fn from(value: &NaiveDate) -> Self {
Value::NaiveDate(*value)
}
}
impl From<&NaiveDateTime> for Value {
fn from(value: &NaiveDateTime) -> Self {
Value::NaiveDateTime(*value)
}
}
/// Extension trait that binds a slice of [`Value`]s onto a sqlx query builder.
///
/// Values are bound in slice order, so the slice must line up with the query's
/// placeholders. They are borrowed for `'q`, the query's lifetime.
pub trait BindValues<'q> {
    /// The query type returned once all values are bound.
    type Output;

    /// Bind each element of `values`, in order, and return the resulting query.
    fn bind_values(self, values: &'q [Value]) -> Self::Output;
}
/// Binds each value, in order, onto an untyped [`Query`].
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
impl<'q> BindValues<'q> for Query<'q, DB, Arguments_<'q>> {
    type Output = Query<'q, DB, Arguments_<'q>>;

    fn bind_values(self, values: &'q [Value]) -> Self::Output {
        values
            .iter()
            .fold(self, |bound, item| bind_value!(bound, item))
    }
}
/// Binds each value, in order, onto a typed [`QueryAs`].
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
impl<'q, O> BindValues<'q> for QueryAs<'q, DB, O, Arguments_<'q>> {
    type Output = QueryAs<'q, DB, O, Arguments_<'q>>;

    fn bind_values(self, values: &'q [Value]) -> Self::Output {
        // `.iter()` rather than `.into_iter()` on a slice reference (clippy::into_iter_on_ref).
        values.iter().fold(self, |query, value| bind_value!(query, value))
    }
}
/// Binds each value, in order, onto a [`QueryScalar`].
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
impl<'q, O> BindValues<'q> for QueryScalar<'q, DB, O, Arguments_<'q>> {
    type Output = QueryScalar<'q, DB, O, Arguments_<'q>>;

    fn bind_values(self, values: &'q [Value]) -> Self::Output {
        values
            .iter()
            .fold(self, |bound, item| bind_value!(bound, item))
    }
}
/// Build a `Vec<Value>` from a comma-separated list of expressions, converting each
/// one with `Value::from`. A trailing comma is accepted.
///
/// ```ignore
/// let params = values!(42_i64, "name", true);
/// ```
///
/// `values!()` expands to an empty vector whose element type is already `Value`, so
/// `let params = values!();` needs no type annotation.
#[macro_export]
macro_rules! values {
    () => {
        ::std::vec::Vec::<$crate::prelude::Value>::new()
    };
    ($($x:expr),+ $(,)?) => {
        ::std::vec![$(<$crate::prelude::Value>::from($x)),+]
    };
}