diff --git a/Cargo.toml b/Cargo.toml index 8facfce..5c155e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true description = "Entity CRUD and change tracking for SQL databases with SQLx" [workspace.package] -version = "0.3.5" +version = "0.3.6" edition = "2021" [dependencies] diff --git a/sqlx-record-derive/src/lib.rs b/sqlx-record-derive/src/lib.rs index 78b2e5f..966d9d3 100644 --- a/sqlx-record-derive/src/lib.rs +++ b/sqlx-record-derive/src/lib.rs @@ -367,68 +367,13 @@ fn generate_insert_impl( .filter(|f| *f != #pk_db_name) .collect(); - #[cfg(feature = "mysql")] - let upsert_stmt = { - let update_clause = non_pk_fields.iter() - .map(|f| format!("{} = VALUES({})", f, f)) - .collect::>() - .join(", "); - format!( - "INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}", - #tq, #table_name, #tq, - vec![#(#db_names),*].join(", "), - placeholders, - update_clause - ) - }; - - #[cfg(feature = "postgres")] - let upsert_stmt = { - let update_clause = non_pk_fields.iter() - .map(|f| format!("{} = EXCLUDED.{}", f, f)) - .collect::>() - .join(", "); - format!( - "INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}", - #tq, #table_name, #tq, - vec![#(#db_names),*].join(", "), - placeholders, - #pk_db_name, - update_clause - ) - }; - - #[cfg(feature = "sqlite")] - let upsert_stmt = { - let update_clause = non_pk_fields.iter() - .map(|f| format!("{} = excluded.{}", f, f)) - .collect::>() - .join(", "); - format!( - "INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}", - #tq, #table_name, #tq, - vec![#(#db_names),*].join(", "), - placeholders, - #pk_db_name, - update_clause - ) - }; - - #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] - let upsert_stmt = { - // Fallback to MySQL syntax - let update_clause = non_pk_fields.iter() - .map(|f| format!("{} = VALUES({})", f, f)) - .collect::>() - .join(", "); - format!( - "INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}", - #tq, #table_name, #tq, - vec![#(#db_names),*].join(", "), - placeholders, - update_clause - ) - }; + let upsert_stmt = ::sqlx_record::prelude::build_upsert_stmt( + #table_name, + &[#(#db_names),*], + #pk_db_name, + &non_pk_fields, + &placeholders, + ); sqlx::query(&upsert_stmt) #(.bind(#bindings))* @@ -778,13 +723,7 @@ fn generate_get_impl( String::new() }; - // Index hints are MySQL-specific - #[cfg(feature = "mysql")] - let index_clause = index - .map(|idx| format!("USE INDEX ({})", idx)) - .unwrap_or_default(); - #[cfg(not(feature = "mysql"))] - let index_clause = { let _ = index; String::new() }; + let index_clause = ::sqlx_record::prelude::build_index_clause(index); //Filter order_by fields to only those managed let fields = Self::select_fields().into_iter().collect::<::std::collections::HashSet<_>>(); @@ -848,23 +787,8 @@ fn generate_get_impl( String::new() }; - // Index hints are MySQL-specific - #[cfg(feature = "mysql")] - let index_clause = index - .map(|idx| format!("USE INDEX ({})", idx)) - .unwrap_or_default(); - #[cfg(not(feature = "mysql"))] - let index_clause = { let _ = index; String::new() }; - - // Use database-appropriate COUNT syntax - #[cfg(feature = "postgres")] - let count_expr = format!("COUNT({})::BIGINT", #pk_db_field_name); - #[cfg(feature = "sqlite")] - let count_expr = format!("COUNT({})", #pk_db_field_name); - #[cfg(feature = "mysql")] - let count_expr = format!("CAST(COUNT({}) AS SIGNED)", #pk_db_field_name); - #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] - let count_expr = format!("COUNT({})", #pk_db_field_name); + let index_clause = ::sqlx_record::prelude::build_index_clause(index); + let count_expr = ::sqlx_record::prelude::build_count_expr(#pk_db_field_name); let query = format!( r#"SELECT {} FROM {}{}{} {} {}"#, @@ -955,13 +879,7 @@ fn generate_get_impl( String::new() }; - // Index hints are MySQL-specific - #[cfg(feature = "mysql")] - let index_clause = index - .map(|idx| format!("USE INDEX ({})", idx)) - .unwrap_or_default(); - #[cfg(not(feature = "mysql"))] - let index_clause = { let _ = index; String::new() }; + let index_clause = ::sqlx_record::prelude::build_index_clause(index); let query = format!( "SELECT DISTINCT {} FROM {}{}{} {} {}", @@ -1085,12 +1003,15 @@ fn generate_update_impl( quote! {} }; - // Auto-update updated_at timestamp + // Auto-update updated_at timestamp (only if not manually set) let updated_at_increment = if has_updated_at { quote! { - parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx))); - values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis())); - idx += 1; + // Only auto-set updated_at if not already set in form or via expression + if self.updated_at.is_none() && !self._exprs.contains_key("updated_at") { + parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx))); + values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis())); + idx += 1; + } } } else { quote! {} @@ -1349,6 +1270,7 @@ fn generate_diff_impl( pub fn to_update_form(&self) -> #update_form_name #ty_generics { #update_form_name { #(#field_idents: Some(self.#field_idents.clone()),)* + _exprs: std::collections::HashMap::new(), } } @@ -1483,6 +1405,52 @@ fn generate_diff_impl( Ok(()) } + + /// Update all records matching the filter conditions + /// Returns the number of affected rows + pub async fn update_by_filter<'a, E>( + executor: E, + filters: Vec<::sqlx_record::prelude::Filter<'a>>, + form: #update_form_name, + ) -> Result + where + E: sqlx::Executor<'a, Database=#db>, + { + use ::sqlx_record::prelude::{Filter, bind_values}; + + if filters.is_empty() { + // Require at least one filter to prevent accidental table-wide updates + return Err(sqlx::Error::Protocol( + "update_by_filter requires at least one filter to prevent accidental table-wide updates".to_string() + )); + } + + let (update_stmt, form_values) = form.update_stmt_with_values(); + if update_stmt.is_empty() { + return Ok(0); + } + + let form_param_count = form_values.len(); + let (where_conditions, filter_values) = Filter::build_where_clause_with_offset(&filters, form_param_count + 1); + + let query_str = format!( + r#"UPDATE {}{}{} SET {} WHERE {}"#, + #tq, Self::table_name(), #tq, + update_stmt, + where_conditions, + ); + + // Combine form values and filter values + let mut all_values = form_values; + all_values.extend(filter_values); + + let query = sqlx::query(&query_str); + let result = bind_values(query, &all_values) + .execute(executor) + .await?; + + Ok(result.rows_affected()) + } } } } diff --git a/src/filter.rs b/src/filter.rs index 54371bd..341d9d3 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -111,6 +111,121 @@ pub fn placeholder(index: usize) -> String { } } +/// Returns the table quote character for the current database +#[inline] +pub fn table_quote() -> &'static str { + #[cfg(feature = "mysql")] + { "`" } + #[cfg(feature = "postgres")] + { "\"" } + #[cfg(feature = "sqlite")] + { "\"" } + #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] + { "`" } +} + +/// Builds an index hint clause (MySQL-specific, empty for other databases) +#[inline] +pub fn build_index_clause(index: Option<&str>) -> String { + #[cfg(feature = "mysql")] + { + index.map(|idx| format!("USE INDEX ({})", idx)).unwrap_or_default() + } + #[cfg(not(feature = "mysql"))] + { + let _ = index; + String::new() + } +} + +/// Builds a COUNT expression appropriate for the database backend +#[inline] +pub fn build_count_expr(field: &str) -> String { + #[cfg(feature = "postgres")] + { + format!("COUNT({})::BIGINT", field) + } + #[cfg(feature = "sqlite")] + { + format!("COUNT({})", field) + } + #[cfg(feature = "mysql")] + { + format!("CAST(COUNT({}) AS SIGNED)", field) + } + #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] + { + format!("COUNT({})", field) + } +} + +/// Builds an upsert statement for the current database backend +pub fn build_upsert_stmt( + table_name: &str, + all_fields: &[&str], + pk_field: &str, + non_pk_fields: &[&str], + placeholders: &str, +) -> String { + let tq = table_quote(); + let fields_str = all_fields.join(", "); + + #[cfg(feature = "mysql")] + { + let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax + let update_clause = non_pk_fields + .iter() + .map(|f| format!("{} = VALUES({})", f, f)) + .collect::>() + .join(", "); + format!( + "INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}", + tq, table_name, tq, fields_str, placeholders, update_clause + ) + } + + #[cfg(feature = "postgres")] + { + let update_clause = non_pk_fields + .iter() + .map(|f| format!("{} = EXCLUDED.{}", f, f)) + .collect::>() + .join(", "); + format!( + "INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}", + tq, table_name, tq, fields_str, placeholders, pk_field, update_clause + ) + } + + #[cfg(feature = "sqlite")] + { + let update_clause = non_pk_fields + .iter() + .map(|f| format!("{} = excluded.{}", f, f)) + .collect::>() + .join(", "); + format!( + "INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}", + tq, table_name, tq, fields_str, placeholders, pk_field, update_clause + ) + } + + #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] + { + let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax + // Fallback to MySQL syntax + let update_clause = non_pk_fields + .iter() + .map(|f| format!("{} = VALUES({})", f, f)) + .collect::>() + .join(", "); + format!( + "INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}", + tq, table_name, tq, fields_str, placeholders, update_clause + ) + } +} + impl Filter<'_> { /// Returns the number of bind parameters this filter will use pub fn param_count(&self) -> usize {