Add update_by_filter for bulk updates by filter conditions
Usage:
User::update_by_filter(&pool, filters![("status", "pending")], form).await?;
- Requires at least one filter to prevent accidental table-wide updates
- Returns number of affected rows
- Binds form values first, then filter values
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
3815913821
commit
a1464d3f7c
|
|
@ -5,7 +5,7 @@ edition.workspace = true
|
|||
description = "Entity CRUD and change tracking for SQL databases with SQLx"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.3.5"
|
||||
version = "0.3.6"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
|
|
|
|||
|
|
@ -367,68 +367,13 @@ fn generate_insert_impl(
|
|||
.filter(|f| *f != #pk_db_name)
|
||||
.collect();
|
||||
|
||||
#[cfg(feature = "mysql")]
|
||||
let upsert_stmt = {
|
||||
let update_clause = non_pk_fields.iter()
|
||||
.map(|f| format!("{} = VALUES({})", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
||||
#tq, #table_name, #tq,
|
||||
vec![#(#db_names),*].join(", "),
|
||||
placeholders,
|
||||
update_clause
|
||||
)
|
||||
};
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
let upsert_stmt = {
|
||||
let update_clause = non_pk_fields.iter()
|
||||
.map(|f| format!("{} = EXCLUDED.{}", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}",
|
||||
#tq, #table_name, #tq,
|
||||
vec![#(#db_names),*].join(", "),
|
||||
placeholders,
|
||||
let upsert_stmt = ::sqlx_record::prelude::build_upsert_stmt(
|
||||
#table_name,
|
||||
&[#(#db_names),*],
|
||||
#pk_db_name,
|
||||
update_clause
|
||||
)
|
||||
};
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
let upsert_stmt = {
|
||||
let update_clause = non_pk_fields.iter()
|
||||
.map(|f| format!("{} = excluded.{}", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}",
|
||||
#tq, #table_name, #tq,
|
||||
vec![#(#db_names),*].join(", "),
|
||||
placeholders,
|
||||
#pk_db_name,
|
||||
update_clause
|
||||
)
|
||||
};
|
||||
|
||||
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||
let upsert_stmt = {
|
||||
// Fallback to MySQL syntax
|
||||
let update_clause = non_pk_fields.iter()
|
||||
.map(|f| format!("{} = VALUES({})", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
||||
#tq, #table_name, #tq,
|
||||
vec![#(#db_names),*].join(", "),
|
||||
placeholders,
|
||||
update_clause
|
||||
)
|
||||
};
|
||||
&non_pk_fields,
|
||||
&placeholders,
|
||||
);
|
||||
|
||||
sqlx::query(&upsert_stmt)
|
||||
#(.bind(#bindings))*
|
||||
|
|
@ -778,13 +723,7 @@ fn generate_get_impl(
|
|||
String::new()
|
||||
};
|
||||
|
||||
// Index hints are MySQL-specific
|
||||
#[cfg(feature = "mysql")]
|
||||
let index_clause = index
|
||||
.map(|idx| format!("USE INDEX ({})", idx))
|
||||
.unwrap_or_default();
|
||||
#[cfg(not(feature = "mysql"))]
|
||||
let index_clause = { let _ = index; String::new() };
|
||||
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
||||
|
||||
//Filter order_by fields to only those managed
|
||||
let fields = Self::select_fields().into_iter().collect::<::std::collections::HashSet<_>>();
|
||||
|
|
@ -848,23 +787,8 @@ fn generate_get_impl(
|
|||
String::new()
|
||||
};
|
||||
|
||||
// Index hints are MySQL-specific
|
||||
#[cfg(feature = "mysql")]
|
||||
let index_clause = index
|
||||
.map(|idx| format!("USE INDEX ({})", idx))
|
||||
.unwrap_or_default();
|
||||
#[cfg(not(feature = "mysql"))]
|
||||
let index_clause = { let _ = index; String::new() };
|
||||
|
||||
// Use database-appropriate COUNT syntax
|
||||
#[cfg(feature = "postgres")]
|
||||
let count_expr = format!("COUNT({})::BIGINT", #pk_db_field_name);
|
||||
#[cfg(feature = "sqlite")]
|
||||
let count_expr = format!("COUNT({})", #pk_db_field_name);
|
||||
#[cfg(feature = "mysql")]
|
||||
let count_expr = format!("CAST(COUNT({}) AS SIGNED)", #pk_db_field_name);
|
||||
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||
let count_expr = format!("COUNT({})", #pk_db_field_name);
|
||||
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
||||
let count_expr = ::sqlx_record::prelude::build_count_expr(#pk_db_field_name);
|
||||
|
||||
let query = format!(
|
||||
r#"SELECT {} FROM {}{}{} {} {}"#,
|
||||
|
|
@ -955,13 +879,7 @@ fn generate_get_impl(
|
|||
String::new()
|
||||
};
|
||||
|
||||
// Index hints are MySQL-specific
|
||||
#[cfg(feature = "mysql")]
|
||||
let index_clause = index
|
||||
.map(|idx| format!("USE INDEX ({})", idx))
|
||||
.unwrap_or_default();
|
||||
#[cfg(not(feature = "mysql"))]
|
||||
let index_clause = { let _ = index; String::new() };
|
||||
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
||||
|
||||
let query = format!(
|
||||
"SELECT DISTINCT {} FROM {}{}{} {} {}",
|
||||
|
|
@ -1085,13 +1003,16 @@ fn generate_update_impl(
|
|||
quote! {}
|
||||
};
|
||||
|
||||
// Auto-update updated_at timestamp
|
||||
// Auto-update updated_at timestamp (only if not manually set)
|
||||
let updated_at_increment = if has_updated_at {
|
||||
quote! {
|
||||
// Only auto-set updated_at if not already set in form or via expression
|
||||
if self.updated_at.is_none() && !self._exprs.contains_key("updated_at") {
|
||||
parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx)));
|
||||
values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis()));
|
||||
idx += 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {}
|
||||
};
|
||||
|
|
@ -1349,6 +1270,7 @@ fn generate_diff_impl(
|
|||
pub fn to_update_form(&self) -> #update_form_name #ty_generics {
|
||||
#update_form_name {
|
||||
#(#field_idents: Some(self.#field_idents.clone()),)*
|
||||
_exprs: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1483,6 +1405,52 @@ fn generate_diff_impl(
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update all records matching the filter conditions
|
||||
/// Returns the number of affected rows
|
||||
pub async fn update_by_filter<'a, E>(
|
||||
executor: E,
|
||||
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
||||
form: #update_form_name,
|
||||
) -> Result<u64, sqlx::Error>
|
||||
where
|
||||
E: sqlx::Executor<'a, Database=#db>,
|
||||
{
|
||||
use ::sqlx_record::prelude::{Filter, bind_values};
|
||||
|
||||
if filters.is_empty() {
|
||||
// Require at least one filter to prevent accidental table-wide updates
|
||||
return Err(sqlx::Error::Protocol(
|
||||
"update_by_filter requires at least one filter to prevent accidental table-wide updates".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
let (update_stmt, form_values) = form.update_stmt_with_values();
|
||||
if update_stmt.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let form_param_count = form_values.len();
|
||||
let (where_conditions, filter_values) = Filter::build_where_clause_with_offset(&filters, form_param_count + 1);
|
||||
|
||||
let query_str = format!(
|
||||
r#"UPDATE {}{}{} SET {} WHERE {}"#,
|
||||
#tq, Self::table_name(), #tq,
|
||||
update_stmt,
|
||||
where_conditions,
|
||||
);
|
||||
|
||||
// Combine form values and filter values
|
||||
let mut all_values = form_values;
|
||||
all_values.extend(filter_values);
|
||||
|
||||
let query = sqlx::query(&query_str);
|
||||
let result = bind_values(query, &all_values)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
115
src/filter.rs
115
src/filter.rs
|
|
@ -111,6 +111,121 @@ pub fn placeholder(index: usize) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the table quote character for the current database
|
||||
#[inline]
|
||||
pub fn table_quote() -> &'static str {
|
||||
#[cfg(feature = "mysql")]
|
||||
{ "`" }
|
||||
#[cfg(feature = "postgres")]
|
||||
{ "\"" }
|
||||
#[cfg(feature = "sqlite")]
|
||||
{ "\"" }
|
||||
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||
{ "`" }
|
||||
}
|
||||
|
||||
/// Builds an index hint clause (MySQL-specific, empty for other databases)
|
||||
#[inline]
|
||||
pub fn build_index_clause(index: Option<&str>) -> String {
|
||||
#[cfg(feature = "mysql")]
|
||||
{
|
||||
index.map(|idx| format!("USE INDEX ({})", idx)).unwrap_or_default()
|
||||
}
|
||||
#[cfg(not(feature = "mysql"))]
|
||||
{
|
||||
let _ = index;
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a COUNT expression appropriate for the database backend
|
||||
#[inline]
|
||||
pub fn build_count_expr(field: &str) -> String {
|
||||
#[cfg(feature = "postgres")]
|
||||
{
|
||||
format!("COUNT({})::BIGINT", field)
|
||||
}
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
format!("COUNT({})", field)
|
||||
}
|
||||
#[cfg(feature = "mysql")]
|
||||
{
|
||||
format!("CAST(COUNT({}) AS SIGNED)", field)
|
||||
}
|
||||
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||
{
|
||||
format!("COUNT({})", field)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds an upsert statement for the current database backend
|
||||
pub fn build_upsert_stmt(
|
||||
table_name: &str,
|
||||
all_fields: &[&str],
|
||||
pk_field: &str,
|
||||
non_pk_fields: &[&str],
|
||||
placeholders: &str,
|
||||
) -> String {
|
||||
let tq = table_quote();
|
||||
let fields_str = all_fields.join(", ");
|
||||
|
||||
#[cfg(feature = "mysql")]
|
||||
{
|
||||
let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax
|
||||
let update_clause = non_pk_fields
|
||||
.iter()
|
||||
.map(|f| format!("{} = VALUES({})", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
||||
tq, table_name, tq, fields_str, placeholders, update_clause
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
{
|
||||
let update_clause = non_pk_fields
|
||||
.iter()
|
||||
.map(|f| format!("{} = EXCLUDED.{}", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}",
|
||||
tq, table_name, tq, fields_str, placeholders, pk_field, update_clause
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
{
|
||||
let update_clause = non_pk_fields
|
||||
.iter()
|
||||
.map(|f| format!("{} = excluded.{}", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}",
|
||||
tq, table_name, tq, fields_str, placeholders, pk_field, update_clause
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
||||
{
|
||||
let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax
|
||||
// Fallback to MySQL syntax
|
||||
let update_clause = non_pk_fields
|
||||
.iter()
|
||||
.map(|f| format!("{} = VALUES({})", f, f))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!(
|
||||
"INSERT INTO {}{}{} ({}) VALUES ({}) ON DUPLICATE KEY UPDATE {}",
|
||||
tq, table_name, tq, fields_str, placeholders, update_clause
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Filter<'_> {
|
||||
/// Returns the number of bind parameters this filter will use
|
||||
pub fn param_count(&self) -> usize {
|
||||
|
|
|
|||
Loading…
Reference in New Issue