diff --git a/Cargo.toml b/Cargo.toml index 2be5095..1ffdea2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true description = "Entity CRUD and change tracking for SQL databases with SQLx" [workspace.package] -version = "0.3.2" +version = "0.3.3" edition = "2021" [dependencies] @@ -16,6 +16,7 @@ uuid = { version = "1", features = ["v4"] } chrono = "0.4" rand = "0.8" paste = "1.0" +rust_decimal = { version = "1", optional = true } [workspace] members = [ @@ -27,7 +28,7 @@ members = [ [features] default = [] derive = ["dep:sqlx-record-derive"] -static-validation = ["sqlx-record-derive?/static-validation"] +decimal = ["dep:rust_decimal", "sqlx/rust_decimal"] # Database backends - user must enable at least one mysql = ["sqlx/mysql", "sqlx-record-derive?/mysql"] diff --git a/sqlx-record-derive/Cargo.toml b/sqlx-record-derive/Cargo.toml index fd03f6a..64787aa 100644 --- a/sqlx-record-derive/Cargo.toml +++ b/sqlx-record-derive/Cargo.toml @@ -13,7 +13,6 @@ futures = "0.3" [features] default = [] -static-validation = [] mysql = [] postgres = [] sqlite = [] diff --git a/sqlx-record-derive/src/lib.rs b/sqlx-record-derive/src/lib.rs index 088ecd6..ebba4bd 100644 --- a/sqlx-record-derive/src/lib.rs +++ b/sqlx-record-derive/src/lib.rs @@ -563,85 +563,43 @@ fn generate_get_impl( quote! {} }; - // Check if static-validation feature is enabled at macro expansion time - let use_static_validation = cfg!(feature = "static-validation"); + let field_list = fields.iter().map(|f| f.db_name.clone()).collect::>(); - let get_by_impl = if use_static_validation { - let select_stmt = format!( - r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = $1"#, - select_fields.clone().collect::>().join(", "), - tq, table_name, tq, pk_db_field_name - ); - quote! { - pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> - where - E: sqlx::Executor<'a, Database=#db>, - { - let result = sqlx::query_as!( - Self, - #select_stmt, - #pk_field - ) - .fetch_optional(executor) - .await?; + let get_by_impl = quote! { + pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> + where + E: sqlx::Executor<'a, Database=#db>, + { + let select_stmt = format!( + r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, + vec![#(#field_list),*].join(","), + #tq, #table_name, #tq, #pk_db_field_name, + ::sqlx_record::prelude::placeholder(1) + ); + let result = sqlx::query_as::<_, Self>(&select_stmt) + .bind(#pk_field) + .fetch_optional(executor) + .await?; - Ok(result) - } - - pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> - where - E: sqlx::Executor<'a, Database=#db>, - { - let result = sqlx::query_as!( - Self, - #select_stmt, - #pk_field - ) - .fetch_optional(executor) - .await?; - - Ok(result) - } + Ok(result) } - } else { - let field_list = fields.iter().map(|f| f.db_name.clone()).collect::>(); - quote! { - pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> - where - E: sqlx::Executor<'a, Database=#db>, - { - let select_stmt = format!( - r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, - vec![#(#field_list),*].join(","), - #tq, #table_name, #tq, #pk_db_field_name, - ::sqlx_record::prelude::placeholder(1) - ); - let result = sqlx::query_as::<_, Self>(&select_stmt) - .bind(#pk_field) - .fetch_optional(executor) - .await?; + pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> + where + E: sqlx::Executor<'a, Database=#db>, + { + let select_stmt = format!( + r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, + vec![#(#field_list),*].join(","), + #tq, #table_name, #tq, #pk_db_field_name, + ::sqlx_record::prelude::placeholder(1) + ); + let result = sqlx::query_as::<_, Self>(&select_stmt) + .bind(#pk_field) + .fetch_optional(executor) + .await?; - Ok(result) - } - - pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> - where - E: sqlx::Executor<'a, Database=#db>, - { - let select_stmt = format!( - r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, - vec![#(#field_list),*].join(","), - #tq, #table_name, #tq, #pk_db_field_name, - ::sqlx_record::prelude::placeholder(1) - ); - let result = sqlx::query_as::<_, Self>(&select_stmt) - .bind(#pk_field) - .fetch_optional(executor) - .await?; - - Ok(result) - } + Ok(result) } }; diff --git a/src/value.rs b/src/value.rs index 30eaeac..7552c03 100644 --- a/src/value.rs +++ b/src/value.rs @@ -1,5 +1,5 @@ use sqlx::query::{Query, QueryAs, QueryScalar}; -use sqlx::types::chrono::{NaiveDate, NaiveDateTime}; +use sqlx::types::chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use crate::filter::placeholder; // Database type alias based on enabled feature @@ -34,6 +34,7 @@ pub type Arguments_<'q> = sqlx::postgres::PgArguments; #[derive(Clone, Debug)] pub enum Value { + Null, Int8(i8), Uint8(u8), Int16(i16), @@ -42,12 +43,18 @@ pub enum Value { Uint32(u32), Int64(i64), Uint64(u64), + Float32(f32), + Float64(f64), VecU8(Vec), String(String), Bool(bool), Uuid(uuid::Uuid), NaiveDate(NaiveDate), NaiveDateTime(NaiveDateTime), + NaiveTime(NaiveTime), + Json(serde_json::Value), + #[cfg(feature = "decimal")] + Decimal(rust_decimal::Decimal), } /// Expression for column updates beyond simple value assignment. @@ -253,6 +260,7 @@ pub type SqlValue = Value; macro_rules! bind_value { ($query:expr, $value: expr) => {{ let query = match $value { + Value::Null => $query.bind(None::), Value::Int8(v) => $query.bind(v), Value::Uint8(v) => $query.bind(v), Value::Int16(v) => $query.bind(v), @@ -261,12 +269,18 @@ macro_rules! bind_value { Value::Uint32(v) => $query.bind(v), Value::Int64(v) => $query.bind(v), Value::Uint64(v) => $query.bind(v), + Value::Float32(v) => $query.bind(v), + Value::Float64(v) => $query.bind(v), Value::VecU8(v) => $query.bind(v), Value::String(v) => $query.bind(v), Value::Bool(v) => $query.bind(v), Value::Uuid(v) => $query.bind(v), Value::NaiveDate(v) => $query.bind(v), Value::NaiveDateTime(v) => $query.bind(v), + Value::NaiveTime(v) => $query.bind(v), + Value::Json(v) => $query.bind(v), + #[cfg(feature = "decimal")] + Value::Decimal(v) => $query.bind(v), }; query }}; @@ -277,6 +291,7 @@ macro_rules! bind_value { macro_rules! bind_value { ($query:expr, $value: expr) => {{ let query = match $value { + Value::Null => $query.bind(None::), Value::Int8(v) => $query.bind(v), Value::Uint8(v) => $query.bind(*v as i16), Value::Int16(v) => $query.bind(v), @@ -285,12 +300,18 @@ macro_rules! bind_value { Value::Uint32(v) => $query.bind(*v as i64), Value::Int64(v) => $query.bind(v), Value::Uint64(v) => $query.bind(*v as i64), + Value::Float32(v) => $query.bind(v), + Value::Float64(v) => $query.bind(v), Value::VecU8(v) => $query.bind(v), Value::String(v) => $query.bind(v), Value::Bool(v) => $query.bind(v), Value::Uuid(v) => $query.bind(v), Value::NaiveDate(v) => $query.bind(v), Value::NaiveDateTime(v) => $query.bind(v), + Value::NaiveTime(v) => $query.bind(v), + Value::Json(v) => $query.bind(v), + #[cfg(feature = "decimal")] + Value::Decimal(v) => $query.bind(v), }; query }}; @@ -309,10 +330,13 @@ pub fn bind_values<'q>(query: Query<'q, DB, Arguments_<'q>>, values: &'q [Value] #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub fn bind_value_owned<'q>(query: Query<'q, DB, Arguments_<'q>>, value: Value) -> Query<'q, DB, Arguments_<'q>> { match value { + Value::Null => query.bind(None::), Value::Int8(v) => query.bind(v), Value::Int16(v) => query.bind(v), Value::Int32(v) => query.bind(v), Value::Int64(v) => query.bind(v), + Value::Float32(v) => query.bind(v), + Value::Float64(v) => query.bind(v), #[cfg(feature = "mysql")] Value::Uint8(v) => query.bind(v), #[cfg(feature = "mysql")] @@ -335,6 +359,10 @@ pub fn bind_value_owned<'q>(query: Query<'q, DB, Arguments_<'q>>, value: Value) Value::Uuid(v) => query.bind(v), Value::NaiveDate(v) => query.bind(v), Value::NaiveDateTime(v) => query.bind(v), + Value::NaiveTime(v) => query.bind(v), + Value::Json(v) => query.bind(v), + #[cfg(feature = "decimal")] + Value::Decimal(v) => query.bind(v), } } @@ -530,6 +558,79 @@ impl From<&NaiveDateTime> for Value { } } +// New type implementations +impl From for Value { + fn from(value: f32) -> Self { + Value::Float32(value) + } +} + +impl From<&f32> for Value { + fn from(value: &f32) -> Self { + Value::Float32(*value) + } +} + +impl From for Value { + fn from(value: f64) -> Self { + Value::Float64(value) + } +} + +impl From<&f64> for Value { + fn from(value: &f64) -> Self { + Value::Float64(*value) + } +} + +impl From for Value { + fn from(value: NaiveTime) -> Self { + Value::NaiveTime(value) + } +} + +impl From<&NaiveTime> for Value { + fn from(value: &NaiveTime) -> Self { + Value::NaiveTime(*value) + } +} + +impl From for Value { + fn from(value: serde_json::Value) -> Self { + Value::Json(value) + } +} + +impl From<&serde_json::Value> for Value { + fn from(value: &serde_json::Value) -> Self { + Value::Json(value.clone()) + } +} + +#[cfg(feature = "decimal")] +impl From for Value { + fn from(value: rust_decimal::Decimal) -> Self { + Value::Decimal(value) + } +} + +#[cfg(feature = "decimal")] +impl From<&rust_decimal::Decimal> for Value { + fn from(value: &rust_decimal::Decimal) -> Self { + Value::Decimal(*value) + } +} + +// Option implementations - convert None to Value::Null +impl> From> for Value { + fn from(value: Option) -> Self { + match value { + Some(v) => v.into(), + None => Value::Null, + } + } +} + pub trait BindValues<'q> { type Output;