Add update_by_filter for bulk updates by filter conditions

Usage:
  User::update_by_filter(&pool, filters![("status", "pending")], form).await?;

- Requires at least one filter to prevent accidental table-wide updates
- Returns number of affected rows
- Binds form values first, then filter values

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Michael Netshipise 2026-01-29 20:55:10 +02:00
parent 3815913821
commit a1464d3f7c
3 changed files with 181 additions and 98 deletions

View File

@ -5,7 +5,7 @@ edition.workspace = true
description = "Entity CRUD and change tracking for SQL databases with SQLx"
[workspace.package]
version = "0.3.5"
version = "0.3.6"
edition = "2021"
[dependencies]

View File

@ -367,68 +367,13 @@ fn generate_insert_impl(
.filter(|f| *f != #pk_db_name)
.collect();
#[cfg(feature = "mysql")]
let upsert_stmt = {
let update_clause = non_pk_fields.iter()
.map(|f| format!("{} = VALUES({})", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
#tq, #table_name, #tq,
vec![#(#db_names),*].join(", "),
placeholders,
update_clause
)
};
#[cfg(feature = "postgres")]
let upsert_stmt = {
let update_clause = non_pk_fields.iter()
.map(|f| format!("{} = EXCLUDED.{}", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}",
#tq, #table_name, #tq,
vec![#(#db_names),*].join(", "),
placeholders,
let upsert_stmt = ::sqlx_record::prelude::build_upsert_stmt(
#table_name,
&[#(#db_names),*],
#pk_db_name,
update_clause
)
};
#[cfg(feature = "sqlite")]
let upsert_stmt = {
let update_clause = non_pk_fields.iter()
.map(|f| format!("{} = excluded.{}", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}",
#tq, #table_name, #tq,
vec![#(#db_names),*].join(", "),
placeholders,
#pk_db_name,
update_clause
)
};
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
let upsert_stmt = {
// Fallback to MySQL syntax
let update_clause = non_pk_fields.iter()
.map(|f| format!("{} = VALUES({})", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
#tq, #table_name, #tq,
vec![#(#db_names),*].join(", "),
placeholders,
update_clause
)
};
&non_pk_fields,
&placeholders,
);
sqlx::query(&upsert_stmt)
#(.bind(#bindings))*
@ -778,13 +723,7 @@ fn generate_get_impl(
String::new()
};
// Index hints are MySQL-specific
#[cfg(feature = "mysql")]
let index_clause = index
.map(|idx| format!("USE INDEX ({})", idx))
.unwrap_or_default();
#[cfg(not(feature = "mysql"))]
let index_clause = { let _ = index; String::new() };
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
//Filter order_by fields to only those managed
let fields = Self::select_fields().into_iter().collect::<::std::collections::HashSet<_>>();
@ -848,23 +787,8 @@ fn generate_get_impl(
String::new()
};
// Index hints are MySQL-specific
#[cfg(feature = "mysql")]
let index_clause = index
.map(|idx| format!("USE INDEX ({})", idx))
.unwrap_or_default();
#[cfg(not(feature = "mysql"))]
let index_clause = { let _ = index; String::new() };
// Use database-appropriate COUNT syntax
#[cfg(feature = "postgres")]
let count_expr = format!("COUNT({})::BIGINT", #pk_db_field_name);
#[cfg(feature = "sqlite")]
let count_expr = format!("COUNT({})", #pk_db_field_name);
#[cfg(feature = "mysql")]
let count_expr = format!("CAST(COUNT({}) AS SIGNED)", #pk_db_field_name);
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
let count_expr = format!("COUNT({})", #pk_db_field_name);
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
let count_expr = ::sqlx_record::prelude::build_count_expr(#pk_db_field_name);
let query = format!(
r#"SELECT {} FROM {}{}{} {} {}"#,
@ -955,13 +879,7 @@ fn generate_get_impl(
String::new()
};
// Index hints are MySQL-specific
#[cfg(feature = "mysql")]
let index_clause = index
.map(|idx| format!("USE INDEX ({})", idx))
.unwrap_or_default();
#[cfg(not(feature = "mysql"))]
let index_clause = { let _ = index; String::new() };
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
let query = format!(
"SELECT DISTINCT {} FROM {}{}{} {} {}",
@ -1085,13 +1003,16 @@ fn generate_update_impl(
quote! {}
};
// Auto-update updated_at timestamp
// Auto-update updated_at timestamp (only if not manually set)
let updated_at_increment = if has_updated_at {
quote! {
// Only auto-set updated_at if not already set in form or via expression
if self.updated_at.is_none() && !self._exprs.contains_key("updated_at") {
parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx)));
values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis()));
idx += 1;
}
}
} else {
quote! {}
};
@ -1349,6 +1270,7 @@ fn generate_diff_impl(
pub fn to_update_form(&self) -> #update_form_name #ty_generics {
#update_form_name {
#(#field_idents: Some(self.#field_idents.clone()),)*
_exprs: std::collections::HashMap::new(),
}
}
@ -1483,6 +1405,52 @@ fn generate_diff_impl(
Ok(())
}
/// Update all records matching the filter conditions
/// Returns the number of affected rows
pub async fn update_by_filter<'a, E>(
executor: E,
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
form: #update_form_name,
) -> Result<u64, sqlx::Error>
where
E: sqlx::Executor<'a, Database=#db>,
{
use ::sqlx_record::prelude::{Filter, bind_values};
if filters.is_empty() {
// Require at least one filter to prevent accidental table-wide updates
return Err(sqlx::Error::Protocol(
"update_by_filter requires at least one filter to prevent accidental table-wide updates".to_string()
));
}
let (update_stmt, form_values) = form.update_stmt_with_values();
if update_stmt.is_empty() {
return Ok(0);
}
let form_param_count = form_values.len();
let (where_conditions, filter_values) = Filter::build_where_clause_with_offset(&filters, form_param_count + 1);
let query_str = format!(
r#"UPDATE {}{}{} SET {} WHERE {}"#,
#tq, Self::table_name(), #tq,
update_stmt,
where_conditions,
);
// Combine form values and filter values
let mut all_values = form_values;
all_values.extend(filter_values);
let query = sqlx::query(&query_str);
let result = bind_values(query, &all_values)
.execute(executor)
.await?;
Ok(result.rows_affected())
}
}
}
}

View File

@ -111,6 +111,121 @@ pub fn placeholder(index: usize) -> String {
}
}
/// Returns the table quote character for the current database
#[inline]
pub fn table_quote() -> &'static str {
#[cfg(feature = "mysql")]
{ "`" }
#[cfg(feature = "postgres")]
{ "\"" }
#[cfg(feature = "sqlite")]
{ "\"" }
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
{ "`" }
}
/// Builds an index hint clause (MySQL-specific, empty for other databases)
#[inline]
pub fn build_index_clause(index: Option<&str>) -> String {
#[cfg(feature = "mysql")]
{
index.map(|idx| format!("USE INDEX ({})", idx)).unwrap_or_default()
}
#[cfg(not(feature = "mysql"))]
{
let _ = index;
String::new()
}
}
/// Builds a COUNT expression appropriate for the database backend
#[inline]
pub fn build_count_expr(field: &str) -> String {
#[cfg(feature = "postgres")]
{
format!("COUNT({})::BIGINT", field)
}
#[cfg(feature = "sqlite")]
{
format!("COUNT({})", field)
}
#[cfg(feature = "mysql")]
{
format!("CAST(COUNT({}) AS SIGNED)", field)
}
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
{
format!("COUNT({})", field)
}
}
/// Builds an upsert statement for the current database backend
pub fn build_upsert_stmt(
table_name: &str,
all_fields: &[&str],
pk_field: &str,
non_pk_fields: &[&str],
placeholders: &str,
) -> String {
let tq = table_quote();
let fields_str = all_fields.join(", ");
#[cfg(feature = "mysql")]
{
let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax
let update_clause = non_pk_fields
.iter()
.map(|f| format!("{} = VALUES({})", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
tq, table_name, tq, fields_str, placeholders, update_clause
)
}
#[cfg(feature = "postgres")]
{
let update_clause = non_pk_fields
.iter()
.map(|f| format!("{} = EXCLUDED.{}", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}",
tq, table_name, tq, fields_str, placeholders, pk_field, update_clause
)
}
#[cfg(feature = "sqlite")]
{
let update_clause = non_pk_fields
.iter()
.map(|f| format!("{} = excluded.{}", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}",
tq, table_name, tq, fields_str, placeholders, pk_field, update_clause
)
}
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
{
let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax
// Fallback to MySQL syntax
let update_clause = non_pk_fields
.iter()
.map(|f| format!("{} = VALUES({})", f, f))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
tq, table_name, tq, fields_str, placeholders, update_clause
)
}
}
impl Filter<'_> {
/// Returns the number of bind parameters this filter will use
pub fn param_count(&self) -> usize {