Add update_by_filter for bulk updates by filter conditions
Usage:
User::update_by_filter(&pool, filters![("status", "pending")], form).await?;
- Requires at least one filter to prevent accidental table-wide updates
- Returns number of affected rows
- Binds form values first, then filter values
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
3815913821
commit
a1464d3f7c
|
|
@ -5,7 +5,7 @@ edition.workspace = true
|
||||||
description = "Entity CRUD and change tracking for SQL databases with SQLx"
|
description = "Entity CRUD and change tracking for SQL databases with SQLx"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.3.5"
|
version = "0.3.6"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|
|
||||||
|
|
@ -367,68 +367,13 @@ fn generate_insert_impl(
|
||||||
.filter(|f| *f != #pk_db_name)
|
.filter(|f| *f != #pk_db_name)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
#[cfg(feature = "mysql")]
|
let upsert_stmt = ::sqlx_record::prelude::build_upsert_stmt(
|
||||||
let upsert_stmt = {
|
#table_name,
|
||||||
let update_clause = non_pk_fields.iter()
|
&[#(#db_names),*],
|
||||||
.map(|f| format!("{} = VALUES({})", f, f))
|
#pk_db_name,
|
||||||
.collect::<Vec<_>>()
|
&non_pk_fields,
|
||||||
.join(", ");
|
&placeholders,
|
||||||
format!(
|
);
|
||||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
|
||||||
#tq, #table_name, #tq,
|
|
||||||
vec![#(#db_names),*].join(", "),
|
|
||||||
placeholders,
|
|
||||||
update_clause
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(feature = "postgres")]
|
|
||||||
let upsert_stmt = {
|
|
||||||
let update_clause = non_pk_fields.iter()
|
|
||||||
.map(|f| format!("{} = EXCLUDED.{}", f, f))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ");
|
|
||||||
format!(
|
|
||||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}",
|
|
||||||
#tq, #table_name, #tq,
|
|
||||||
vec![#(#db_names),*].join(", "),
|
|
||||||
placeholders,
|
|
||||||
#pk_db_name,
|
|
||||||
update_clause
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(feature = "sqlite")]
|
|
||||||
let upsert_stmt = {
|
|
||||||
let update_clause = non_pk_fields.iter()
|
|
||||||
.map(|f| format!("{} = excluded.{}", f, f))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ");
|
|
||||||
format!(
|
|
||||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}",
|
|
||||||
#tq, #table_name, #tq,
|
|
||||||
vec![#(#db_names),*].join(", "),
|
|
||||||
placeholders,
|
|
||||||
#pk_db_name,
|
|
||||||
update_clause
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
|
||||||
let upsert_stmt = {
|
|
||||||
// Fallback to MySQL syntax
|
|
||||||
let update_clause = non_pk_fields.iter()
|
|
||||||
.map(|f| format!("{} = VALUES({})", f, f))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ");
|
|
||||||
format!(
|
|
||||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
|
||||||
#tq, #table_name, #tq,
|
|
||||||
vec![#(#db_names),*].join(", "),
|
|
||||||
placeholders,
|
|
||||||
update_clause
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
sqlx::query(&upsert_stmt)
|
sqlx::query(&upsert_stmt)
|
||||||
#(.bind(#bindings))*
|
#(.bind(#bindings))*
|
||||||
|
|
@ -778,13 +723,7 @@ fn generate_get_impl(
|
||||||
String::new()
|
String::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Index hints are MySQL-specific
|
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
||||||
#[cfg(feature = "mysql")]
|
|
||||||
let index_clause = index
|
|
||||||
.map(|idx| format!("USE INDEX ({})", idx))
|
|
||||||
.unwrap_or_default();
|
|
||||||
#[cfg(not(feature = "mysql"))]
|
|
||||||
let index_clause = { let _ = index; String::new() };
|
|
||||||
|
|
||||||
//Filter order_by fields to only those managed
|
//Filter order_by fields to only those managed
|
||||||
let fields = Self::select_fields().into_iter().collect::<::std::collections::HashSet<_>>();
|
let fields = Self::select_fields().into_iter().collect::<::std::collections::HashSet<_>>();
|
||||||
|
|
@ -848,23 +787,8 @@ fn generate_get_impl(
|
||||||
String::new()
|
String::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Index hints are MySQL-specific
|
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
||||||
#[cfg(feature = "mysql")]
|
let count_expr = ::sqlx_record::prelude::build_count_expr(#pk_db_field_name);
|
||||||
let index_clause = index
|
|
||||||
.map(|idx| format!("USE INDEX ({})", idx))
|
|
||||||
.unwrap_or_default();
|
|
||||||
#[cfg(not(feature = "mysql"))]
|
|
||||||
let index_clause = { let _ = index; String::new() };
|
|
||||||
|
|
||||||
// Use database-appropriate COUNT syntax
|
|
||||||
#[cfg(feature = "postgres")]
|
|
||||||
let count_expr = format!("COUNT({})::BIGINT", #pk_db_field_name);
|
|
||||||
#[cfg(feature = "sqlite")]
|
|
||||||
let count_expr = format!("COUNT({})", #pk_db_field_name);
|
|
||||||
#[cfg(feature = "mysql")]
|
|
||||||
let count_expr = format!("CAST(COUNT({}) AS SIGNED)", #pk_db_field_name);
|
|
||||||
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
|
||||||
let count_expr = format!("COUNT({})", #pk_db_field_name);
|
|
||||||
|
|
||||||
let query = format!(
|
let query = format!(
|
||||||
r#"SELECT {} FROM {}{}{} {} {}"#,
|
r#"SELECT {} FROM {}{}{} {} {}"#,
|
||||||
|
|
@ -955,13 +879,7 @@ fn generate_get_impl(
|
||||||
String::new()
|
String::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Index hints are MySQL-specific
|
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
||||||
#[cfg(feature = "mysql")]
|
|
||||||
let index_clause = index
|
|
||||||
.map(|idx| format!("USE INDEX ({})", idx))
|
|
||||||
.unwrap_or_default();
|
|
||||||
#[cfg(not(feature = "mysql"))]
|
|
||||||
let index_clause = { let _ = index; String::new() };
|
|
||||||
|
|
||||||
let query = format!(
|
let query = format!(
|
||||||
"SELECT DISTINCT {} FROM {}{}{} {} {}",
|
"SELECT DISTINCT {} FROM {}{}{} {} {}",
|
||||||
|
|
@ -1085,12 +1003,15 @@ fn generate_update_impl(
|
||||||
quote! {}
|
quote! {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Auto-update updated_at timestamp
|
// Auto-update updated_at timestamp (only if not manually set)
|
||||||
let updated_at_increment = if has_updated_at {
|
let updated_at_increment = if has_updated_at {
|
||||||
quote! {
|
quote! {
|
||||||
parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx)));
|
// Only auto-set updated_at if not already set in form or via expression
|
||||||
values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis()));
|
if self.updated_at.is_none() && !self._exprs.contains_key("updated_at") {
|
||||||
idx += 1;
|
parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx)));
|
||||||
|
values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis()));
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
quote! {}
|
quote! {}
|
||||||
|
|
@ -1349,6 +1270,7 @@ fn generate_diff_impl(
|
||||||
pub fn to_update_form(&self) -> #update_form_name #ty_generics {
|
pub fn to_update_form(&self) -> #update_form_name #ty_generics {
|
||||||
#update_form_name {
|
#update_form_name {
|
||||||
#(#field_idents: Some(self.#field_idents.clone()),)*
|
#(#field_idents: Some(self.#field_idents.clone()),)*
|
||||||
|
_exprs: std::collections::HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1483,6 +1405,52 @@ fn generate_diff_impl(
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Update all records matching the filter conditions
|
||||||
|
/// Returns the number of affected rows
|
||||||
|
pub async fn update_by_filter<'a, E>(
|
||||||
|
executor: E,
|
||||||
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
||||||
|
form: #update_form_name,
|
||||||
|
) -> Result<u64, sqlx::Error>
|
||||||
|
where
|
||||||
|
E: sqlx::Executor<'a, Database=#db>,
|
||||||
|
{
|
||||||
|
use ::sqlx_record::prelude::{Filter, bind_values};
|
||||||
|
|
||||||
|
if filters.is_empty() {
|
||||||
|
// Require at least one filter to prevent accidental table-wide updates
|
||||||
|
return Err(sqlx::Error::Protocol(
|
||||||
|
"update_by_filter requires at least one filter to prevent accidental table-wide updates".to_string()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (update_stmt, form_values) = form.update_stmt_with_values();
|
||||||
|
if update_stmt.is_empty() {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let form_param_count = form_values.len();
|
||||||
|
let (where_conditions, filter_values) = Filter::build_where_clause_with_offset(&filters, form_param_count + 1);
|
||||||
|
|
||||||
|
let query_str = format!(
|
||||||
|
r#"UPDATE {}{}{} SET {} WHERE {}"#,
|
||||||
|
#tq, Self::table_name(), #tq,
|
||||||
|
update_stmt,
|
||||||
|
where_conditions,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Combine form values and filter values
|
||||||
|
let mut all_values = form_values;
|
||||||
|
all_values.extend(filter_values);
|
||||||
|
|
||||||
|
let query = sqlx::query(&query_str);
|
||||||
|
let result = bind_values(query, &all_values)
|
||||||
|
.execute(executor)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(result.rows_affected())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
115
src/filter.rs
115
src/filter.rs
|
|
@ -111,6 +111,121 @@ pub fn placeholder(index: usize) -> String {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the table quote character for the current database
|
||||||
|
#[inline]
|
||||||
|
pub fn table_quote() -> &'static str {
|
||||||
|
#[cfg(feature = "mysql")]
|
||||||
|
{ "`" }
|
||||||
|
#[cfg(feature = "postgres")]
|
||||||
|
{ "\"" }
|
||||||
|
#[cfg(feature = "sqlite")]
|
||||||
|
{ "\"" }
|
||||||
|
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||||
|
{ "`" }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds an index hint clause (MySQL-specific, empty for other databases)
|
||||||
|
#[inline]
|
||||||
|
pub fn build_index_clause(index: Option<&str>) -> String {
|
||||||
|
#[cfg(feature = "mysql")]
|
||||||
|
{
|
||||||
|
index.map(|idx| format!("USE INDEX ({})", idx)).unwrap_or_default()
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "mysql"))]
|
||||||
|
{
|
||||||
|
let _ = index;
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a COUNT expression appropriate for the database backend
|
||||||
|
#[inline]
|
||||||
|
pub fn build_count_expr(field: &str) -> String {
|
||||||
|
#[cfg(feature = "postgres")]
|
||||||
|
{
|
||||||
|
format!("COUNT({})::BIGINT", field)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "sqlite")]
|
||||||
|
{
|
||||||
|
format!("COUNT({})", field)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "mysql")]
|
||||||
|
{
|
||||||
|
format!("CAST(COUNT({}) AS SIGNED)", field)
|
||||||
|
}
|
||||||
|
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||||
|
{
|
||||||
|
format!("COUNT({})", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds an upsert statement for the current database backend
|
||||||
|
pub fn build_upsert_stmt(
|
||||||
|
table_name: &str,
|
||||||
|
all_fields: &[&str],
|
||||||
|
pk_field: &str,
|
||||||
|
non_pk_fields: &[&str],
|
||||||
|
placeholders: &str,
|
||||||
|
) -> String {
|
||||||
|
let tq = table_quote();
|
||||||
|
let fields_str = all_fields.join(", ");
|
||||||
|
|
||||||
|
#[cfg(feature = "mysql")]
|
||||||
|
{
|
||||||
|
let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax
|
||||||
|
let update_clause = non_pk_fields
|
||||||
|
.iter()
|
||||||
|
.map(|f| format!("{} = VALUES({})", f, f))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
format!(
|
||||||
|
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
||||||
|
tq, table_name, tq, fields_str, placeholders, update_clause
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "postgres")]
|
||||||
|
{
|
||||||
|
let update_clause = non_pk_fields
|
||||||
|
.iter()
|
||||||
|
.map(|f| format!("{} = EXCLUDED.{}", f, f))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
format!(
|
||||||
|
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}",
|
||||||
|
tq, table_name, tq, fields_str, placeholders, pk_field, update_clause
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sqlite")]
|
||||||
|
{
|
||||||
|
let update_clause = non_pk_fields
|
||||||
|
.iter()
|
||||||
|
.map(|f| format!("{} = excluded.{}", f, f))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
format!(
|
||||||
|
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}",
|
||||||
|
tq, table_name, tq, fields_str, placeholders, pk_field, update_clause
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||||
|
{
|
||||||
|
let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax
|
||||||
|
// Fallback to MySQL syntax
|
||||||
|
let update_clause = non_pk_fields
|
||||||
|
.iter()
|
||||||
|
.map(|f| format!("{} = VALUES({})", f, f))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
format!(
|
||||||
|
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
||||||
|
tq, table_name, tq, fields_str, placeholders, update_clause
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Filter<'_> {
|
impl Filter<'_> {
|
||||||
/// Returns the number of bind parameters this filter will use
|
/// Returns the number of bind parameters this filter will use
|
||||||
pub fn param_count(&self) -> usize {
|
pub fn param_count(&self) -> usize {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue