use sqlx::query::{Query, QueryAs, QueryScalar}; use sqlx::types::chrono::{NaiveDate, NaiveDateTime}; use crate::filter::placeholder; // Database type alias based on enabled feature #[cfg(feature = "mysql")] pub type DB = sqlx::MySql; #[cfg(feature = "postgres")] pub type DB = sqlx::Postgres; #[cfg(feature = "sqlite")] pub type DB = sqlx::Sqlite; // Arguments type alias (used for non-lifetime-sensitive contexts) #[cfg(feature = "mysql")] pub type Arguments = sqlx::mysql::MySqlArguments; #[cfg(feature = "postgres")] pub type Arguments = sqlx::postgres::PgArguments; #[cfg(feature = "sqlite")] pub type Arguments = sqlx::sqlite::SqliteArguments<'static>; // Lifetime-aware arguments type for SQLite #[cfg(feature = "sqlite")] pub type Arguments_<'q> = sqlx::sqlite::SqliteArguments<'q>; #[cfg(feature = "mysql")] pub type Arguments_<'q> = sqlx::mysql::MySqlArguments; #[cfg(feature = "postgres")] pub type Arguments_<'q> = sqlx::postgres::PgArguments; #[derive(Clone, Debug)] pub enum Value { Int8(i8), Uint8(u8), Int16(i16), Uint16(u16), Int32(i32), Uint32(u32), Int64(i64), Uint64(u64), VecU8(Vec), String(String), Bool(bool), Uuid(uuid::Uuid), NaiveDate(NaiveDate), NaiveDateTime(NaiveDateTime), } /// Expression for column updates beyond simple value assignment. /// Used with `eval_*` methods on UpdateForm. #[derive(Clone, Debug)] pub enum UpdateExpr { /// column = value (equivalent to with_* methods) Set(Value), /// column = column + value Add(Value), /// column = column - value Sub(Value), /// column = column * value Mul(Value), /// column = column / value Div(Value), /// column = column % value Mod(Value), /// column = CASE WHEN cond1 THEN val1 WHEN cond2 THEN val2 ... ELSE default END Case { /// Vec of (condition, value) pairs for WHEN branches branches: Vec<(crate::filter::Filter<'static>, Value)>, /// Default value for ELSE branch default: Value, }, /// column = CASE WHEN condition THEN column + value ELSE column END AddIf { condition: crate::filter::Filter<'static>, value: Value, }, /// column = CASE WHEN condition THEN column - value ELSE column END SubIf { condition: crate::filter::Filter<'static>, value: Value, }, /// column = COALESCE(column, value) Coalesce(Value), /// column = GREATEST(column, value) - MySQL/PostgreSQL only Greatest(Value), /// column = LEAST(column, value) - MySQL/PostgreSQL only Least(Value), /// Raw SQL expression escape hatch: column = {sql} /// Placeholders in sql should use `?` and will be replaced with proper placeholders Raw { sql: String, values: Vec, }, } impl UpdateExpr { /// Build SQL expression and collect values for binding. /// Returns (sql_fragment, values_to_bind) pub fn build_sql(&self, column: &str, start_idx: usize) -> (String, Vec) { use crate::filter::Filter; match self { UpdateExpr::Set(v) => { (placeholder(start_idx), vec![v.clone()]) } UpdateExpr::Add(v) => { (format!("{} + {}", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Sub(v) => { (format!("{} - {}", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Mul(v) => { (format!("{} * {}", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Div(v) => { (format!("{} / {}", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Mod(v) => { (format!("{} % {}", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Case { branches, default } => { let mut sql_parts = vec!["CASE".to_string()]; let mut values = Vec::new(); let mut idx = start_idx; for (condition, value) in branches { let (cond_sql, cond_values) = Filter::build_where_clause_with_offset( &[condition.clone()], idx, ); idx += cond_values.len(); values.extend(cond_values); sql_parts.push(format!("WHEN {} THEN {}", cond_sql, placeholder(idx))); values.push(value.clone()); idx += 1; } sql_parts.push(format!("ELSE {} END", placeholder(idx))); values.push(default.clone()); (sql_parts.join(" "), values) } UpdateExpr::AddIf { condition, value } => { let (cond_sql, cond_values) = Filter::build_where_clause_with_offset( &[condition.clone()], start_idx, ); let mut values = cond_values; let val_idx = start_idx + values.len(); let sql = format!( "CASE WHEN {} THEN {} + {} ELSE {} END", cond_sql, column, placeholder(val_idx), column ); values.push(value.clone()); (sql, values) } UpdateExpr::SubIf { condition, value } => { let (cond_sql, cond_values) = Filter::build_where_clause_with_offset( &[condition.clone()], start_idx, ); let mut values = cond_values; let val_idx = start_idx + values.len(); let sql = format!( "CASE WHEN {} THEN {} - {} ELSE {} END", cond_sql, column, placeholder(val_idx), column ); values.push(value.clone()); (sql, values) } UpdateExpr::Coalesce(v) => { (format!("COALESCE({}, {})", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Greatest(v) => { (format!("GREATEST({}, {})", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Least(v) => { (format!("LEAST({}, {})", column, placeholder(start_idx)), vec![v.clone()]) } UpdateExpr::Raw { sql, values } => { // Replace ? placeholders with proper database placeholders let mut result_sql = String::new(); let mut placeholder_count = 0; for ch in sql.chars() { if ch == '?' { result_sql.push_str(&placeholder(start_idx + placeholder_count)); placeholder_count += 1; } else { result_sql.push(ch); } } (result_sql, values.clone()) } } } /// Returns the number of bind parameters this expression will use pub fn param_count(&self) -> usize { match self { UpdateExpr::Set(_) => 1, UpdateExpr::Add(_) => 1, UpdateExpr::Sub(_) => 1, UpdateExpr::Mul(_) => 1, UpdateExpr::Div(_) => 1, UpdateExpr::Mod(_) => 1, UpdateExpr::Case { branches, default: _ } => { branches.iter().map(|(f, _)| f.param_count() + 1).sum::() + 1 } UpdateExpr::AddIf { condition, value: _ } => condition.param_count() + 1, UpdateExpr::SubIf { condition, value: _ } => condition.param_count() + 1, UpdateExpr::Coalesce(_) => 1, UpdateExpr::Greatest(_) => 1, UpdateExpr::Least(_) => 1, UpdateExpr::Raw { sql: _, values } => values.len(), } } } #[deprecated(since = "0.1.0", note = "Please use Value instead")] pub type SqlValue = Value; // MySQL supports unsigned integers natively #[cfg(feature = "mysql")] macro_rules! bind_value { ($query:expr, $value: expr) => {{ let query = match $value { Value::Int8(v) => $query.bind(v), Value::Uint8(v) => $query.bind(v), Value::Int16(v) => $query.bind(v), Value::Uint16(v) => $query.bind(v), Value::Int32(v) => $query.bind(v), Value::Uint32(v) => $query.bind(v), Value::Int64(v) => $query.bind(v), Value::Uint64(v) => $query.bind(v), Value::VecU8(v) => $query.bind(v), Value::String(v) => $query.bind(v), Value::Bool(v) => $query.bind(v), Value::Uuid(v) => $query.bind(v), Value::NaiveDate(v) => $query.bind(v), Value::NaiveDateTime(v) => $query.bind(v), }; query }}; } // PostgreSQL and SQLite don't support unsigned integers - convert to signed #[cfg(any(feature = "postgres", feature = "sqlite"))] macro_rules! bind_value { ($query:expr, $value: expr) => {{ let query = match $value { Value::Int8(v) => $query.bind(v), Value::Uint8(v) => $query.bind(*v as i16), Value::Int16(v) => $query.bind(v), Value::Uint16(v) => $query.bind(*v as i32), Value::Int32(v) => $query.bind(v), Value::Uint32(v) => $query.bind(*v as i64), Value::Int64(v) => $query.bind(v), Value::Uint64(v) => $query.bind(*v as i64), Value::VecU8(v) => $query.bind(v), Value::String(v) => $query.bind(v), Value::Bool(v) => $query.bind(v), Value::Uuid(v) => $query.bind(v), Value::NaiveDate(v) => $query.bind(v), Value::NaiveDateTime(v) => $query.bind(v), }; query }}; } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub fn bind_values<'q>(query: Query<'q, DB, Arguments_<'q>>, values: &'q [Value]) -> Query<'q, DB, Arguments_<'q>> { let mut query = query; for value in values { query = bind_value!(query, value); } query } /// Bind a single owned Value to a query #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub fn bind_value_owned<'q>(query: Query<'q, DB, Arguments_<'q>>, value: Value) -> Query<'q, DB, Arguments_<'q>> { match value { Value::Int8(v) => query.bind(v), Value::Int16(v) => query.bind(v), Value::Int32(v) => query.bind(v), Value::Int64(v) => query.bind(v), #[cfg(feature = "mysql")] Value::Uint8(v) => query.bind(v), #[cfg(feature = "mysql")] Value::Uint16(v) => query.bind(v), #[cfg(feature = "mysql")] Value::Uint32(v) => query.bind(v), #[cfg(feature = "mysql")] Value::Uint64(v) => query.bind(v), #[cfg(any(feature = "postgres", feature = "sqlite"))] Value::Uint8(v) => query.bind(v as i16), #[cfg(any(feature = "postgres", feature = "sqlite"))] Value::Uint16(v) => query.bind(v as i32), #[cfg(any(feature = "postgres", feature = "sqlite"))] Value::Uint32(v) => query.bind(v as i64), #[cfg(any(feature = "postgres", feature = "sqlite"))] Value::Uint64(v) => query.bind(v as i64), Value::VecU8(v) => query.bind(v), Value::String(v) => query.bind(v), Value::Bool(v) => query.bind(v), Value::Uuid(v) => query.bind(v), Value::NaiveDate(v) => query.bind(v), Value::NaiveDateTime(v) => query.bind(v), } } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub fn bind_as_values<'q, O>(query: QueryAs<'q, DB, O, Arguments_<'q>>, values: &'q [Value]) -> QueryAs<'q, DB, O, Arguments_<'q>> { values.into_iter().fold(query, |query, value| { bind_value!(query, value) }) } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub fn bind_scalar_values<'q, O>(query: QueryScalar<'q, DB, O, Arguments_<'q>>, values: &'q [Value]) -> QueryScalar<'q, DB, O, Arguments_<'q>> { let mut query = query; for value in values { query = bind_value!(query, value); } query } #[inline] pub fn query_fields(fields: Vec<&str>) -> String { fields.iter().filter_map(|e| e.split(" ").next()) .collect::>().join(", ") } // From implementations for owned values impl From for Value { fn from(value: String) -> Self { Value::String(value) } } impl From for Value { fn from(value: i8) -> Self { Value::Int8(value) } } impl From for Value { fn from(value: u8) -> Self { Value::Uint8(value) } } impl From for Value { fn from(value: i16) -> Self { Value::Int16(value) } } impl From for Value { fn from(value: u16) -> Self { Value::Uint16(value) } } impl From for Value { fn from(value: i32) -> Self { Value::Int32(value) } } impl From for Value { fn from(value: u32) -> Self { Value::Uint32(value) } } impl From for Value { fn from(value: i64) -> Self { Value::Int64(value) } } impl From for Value { fn from(value: u64) -> Self { Value::Uint64(value) } } impl From for Value { fn from(value: bool) -> Self { Value::Bool(value) } } impl From for Value { fn from(value: uuid::Uuid) -> Self { Value::Uuid(value) } } impl From> for Value { fn from(value: Vec) -> Self { Value::VecU8(value) } } impl From for Value { fn from(value: NaiveDate) -> Self { Value::NaiveDate(value) } } impl From for Value { fn from(value: NaiveDateTime) -> Self { Value::NaiveDateTime(value) } } // From implementations for references impl From<&str> for Value { fn from(value: &str) -> Self { Value::String(value.to_string()) } } impl From<&String> for Value { fn from(value: &String) -> Self { Value::String(value.clone()) } } impl From<&i8> for Value { fn from(value: &i8) -> Self { Value::Int8(*value) } } impl From<&u8> for Value { fn from(value: &u8) -> Self { Value::Uint8(*value) } } impl From<&i16> for Value { fn from(value: &i16) -> Self { Value::Int16(*value) } } impl From<&u16> for Value { fn from(value: &u16) -> Self { Value::Uint16(*value) } } impl From<&i32> for Value { fn from(value: &i32) -> Self { Value::Int32(*value) } } impl From<&u32> for Value { fn from(value: &u32) -> Self { Value::Uint32(*value) } } impl From<&i64> for Value { fn from(value: &i64) -> Self { Value::Int64(*value) } } impl From<&u64> for Value { fn from(value: &u64) -> Self { Value::Uint64(*value) } } impl From<&bool> for Value { fn from(value: &bool) -> Self { Value::Bool(*value) } } impl From<&uuid::Uuid> for Value { fn from(value: &uuid::Uuid) -> Self { Value::Uuid(*value) } } impl From<&NaiveDate> for Value { fn from(value: &NaiveDate) -> Self { Value::NaiveDate(*value) } } impl From<&NaiveDateTime> for Value { fn from(value: &NaiveDateTime) -> Self { Value::NaiveDateTime(*value) } } pub trait BindValues<'q> { type Output; fn bind_values(self, values: &'q [Value]) -> Self::Output; } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] impl<'q> BindValues<'q> for Query<'q, DB, Arguments_<'q>> { type Output = Query<'q, DB, Arguments_<'q>>; fn bind_values(self, values: &'q [Value]) -> Self::Output { let mut query = self; for value in values { query = bind_value!(query, value); } query } } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] impl<'q, O> BindValues<'q> for QueryAs<'q, DB, O, Arguments_<'q>> { type Output = QueryAs<'q, DB, O, Arguments_<'q>>; fn bind_values(self, values: &'q [Value]) -> Self::Output { values.into_iter().fold(self, |query, value| { bind_value!(query, value) }) } } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] impl<'q, O> BindValues<'q> for QueryScalar<'q, DB, O, Arguments_<'q>> { type Output = QueryScalar<'q, DB, O, Arguments_<'q>>; fn bind_values(self, values: &'q [Value]) -> Self::Output { let mut query = self; for value in values { query = bind_value!(query, value); } query } } #[macro_export] macro_rules! values { () => { vec![] }; ($x:expr) => { vec![<$crate::prelude::Value>::from($x)] }; ($($x:expr),+ $(,)?) => { vec![$(<$crate::prelude::Value>::from($x)),+] }; }