diff --git a/Cargo.toml b/Cargo.toml index 1ffdea2..3476c16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true description = "Entity CRUD and change tracking for SQL databases with SQLx" [workspace.package] -version = "0.3.3" +version = "0.3.4" edition = "2021" [dependencies] @@ -28,6 +28,7 @@ members = [ [features] default = [] derive = ["dep:sqlx-record-derive"] +static-validation = ["sqlx-record-derive?/static-validation"] decimal = ["dep:rust_decimal", "sqlx/rust_decimal"] # Database backends - user must enable at least one diff --git a/sqlx-record-derive/Cargo.toml b/sqlx-record-derive/Cargo.toml index 64787aa..fd03f6a 100644 --- a/sqlx-record-derive/Cargo.toml +++ b/sqlx-record-derive/Cargo.toml @@ -13,6 +13,7 @@ futures = "0.3" [features] default = [] +static-validation = [] mysql = [] postgres = [] sqlite = [] diff --git a/sqlx-record-derive/src/lib.rs b/sqlx-record-derive/src/lib.rs index ebba4bd..858e01b 100644 --- a/sqlx-record-derive/src/lib.rs +++ b/sqlx-record-derive/src/lib.rs @@ -106,6 +106,14 @@ fn table_quote() -> &'static str { { "`" } } +/// Get compile-time placeholder for static-validation SQL +fn static_placeholder(index: usize) -> String { + #[cfg(feature = "postgres")] + { format!("${}", index) } + #[cfg(not(feature = "postgres"))] + { let _ = index; "?".to_string() } +} + fn derive_entity_internal(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let name = &input.ident; @@ -565,41 +573,86 @@ fn generate_get_impl( let field_list = fields.iter().map(|f| f.db_name.clone()).collect::>(); - let get_by_impl = quote! { - pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> - where - E: sqlx::Executor<'a, Database=#db>, - { - let select_stmt = format!( - r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, - vec![#(#field_list),*].join(","), - #tq, #table_name, #tq, #pk_db_field_name, - ::sqlx_record::prelude::placeholder(1) - ); - let result = sqlx::query_as::<_, Self>(&select_stmt) - .bind(#pk_field) - .fetch_optional(executor) - .await?; + // Check if static-validation feature is enabled at macro expansion time + let use_static_validation = cfg!(feature = "static-validation"); - Ok(result) + let get_by_impl = if use_static_validation { + // Static validation: use sqlx::query_as! with compile-time checked SQL + let select_stmt = format!( + r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, + select_fields.clone().collect::>().join(", "), + tq, table_name, tq, pk_db_field_name, + static_placeholder(1) + ); + quote! { + pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> + where + E: sqlx::Executor<'a, Database=#db>, + { + let result = sqlx::query_as!( + Self, + #select_stmt, + #pk_field + ) + .fetch_optional(executor) + .await?; + + Ok(result) + } + + pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> + where + E: sqlx::Executor<'a, Database=#db>, + { + let result = sqlx::query_as!( + Self, + #select_stmt, + #pk_field + ) + .fetch_optional(executor) + .await?; + + Ok(result) + } } + } else { + // Runtime: use sqlx::query_as with dynamic SQL + quote! { + pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> + where + E: sqlx::Executor<'a, Database=#db>, + { + let select_stmt = format!( + r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, + vec![#(#field_list),*].join(","), + #tq, #table_name, #tq, #pk_db_field_name, + ::sqlx_record::prelude::placeholder(1) + ); + let result = sqlx::query_as::<_, Self>(&select_stmt) + .bind(#pk_field) + .fetch_optional(executor) + .await?; - pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> - where - E: sqlx::Executor<'a, Database=#db>, - { - let select_stmt = format!( - r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, - vec![#(#field_list),*].join(","), - #tq, #table_name, #tq, #pk_db_field_name, - ::sqlx_record::prelude::placeholder(1) - ); - let result = sqlx::query_as::<_, Self>(&select_stmt) - .bind(#pk_field) - .fetch_optional(executor) - .await?; + Ok(result) + } - Ok(result) + pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> + where + E: sqlx::Executor<'a, Database=#db>, + { + let select_stmt = format!( + r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, + vec![#(#field_list),*].join(","), + #tq, #table_name, #tq, #pk_db_field_name, + ::sqlx_record::prelude::placeholder(1) + ); + let result = sqlx::query_as::<_, Self>(&select_stmt) + .bind(#pk_field) + .fetch_optional(executor) + .await?; + + Ok(result) + } } };