1590 lines
59 KiB
Rust
1590 lines
59 KiB
Rust
mod string_utils;
|
|
extern crate proc_macro;
|
|
|
|
use proc_macro::TokenStream;
|
|
use proc_macro2::{Ident, TokenStream as TokenStream2};
|
|
use quote::{quote, format_ident};
|
|
use syn::{parse_macro_input, DeriveInput, Data, LitStr, Type, ImplGenerics, TypeGenerics, WhereClause};
|
|
|
|
use crate::string_utils::{pluralize, to_snake_case};
|
|
|
|
|
|
struct EntityField {
|
|
ident: Ident,
|
|
db_name: String,
|
|
ty: Type,
|
|
needs_type_annotation: bool,
|
|
type_override: Option<String>,
|
|
is_primary_key: bool,
|
|
is_version_field: bool,
|
|
is_soft_delete: bool,
|
|
is_created_at: bool,
|
|
is_updated_at: bool,
|
|
}
|
|
|
|
/// Parse a string attribute that can be either:
|
|
/// - `#[attr("value")]` (Meta::List style)
|
|
/// - `#[attr = "value"]` (Meta::NameValue style)
|
|
fn parse_string_attr(attr: &syn::Attribute) -> Option<String> {
|
|
match &attr.meta {
|
|
syn::Meta::List(_) => {
|
|
// #[attr("value")] style
|
|
attr.parse_args::<LitStr>().ok().map(|lit| lit.value())
|
|
}
|
|
syn::Meta::NameValue(nv) => {
|
|
// #[attr = "value"] style
|
|
if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. }) = &nv.value {
|
|
Some(lit.value())
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
// Support both Update and Entity attributes
|
|
#[proc_macro_derive(Update, attributes(rename, table_name, primary_key, field_type))]
|
|
pub fn derive_update(input: TokenStream) -> TokenStream {
|
|
derive_entity_internal(input)
|
|
}
|
|
|
|
#[proc_macro_derive(Entity, attributes(rename, table_name, primary_key, version, field_type, soft_delete, created_at, updated_at))]
|
|
pub fn derive_entity(input: TokenStream) -> TokenStream {
|
|
derive_entity_internal(input)
|
|
}
|
|
|
|
/// Generate database-specific types based on features
|
|
fn db_type() -> TokenStream2 {
|
|
#[cfg(feature = "postgres")]
|
|
{
|
|
quote! { sqlx::Postgres }
|
|
}
|
|
#[cfg(feature = "sqlite")]
|
|
{
|
|
return quote! { sqlx::Sqlite };
|
|
}
|
|
#[cfg(feature = "mysql")]
|
|
{
|
|
quote! { sqlx::MySql }
|
|
}
|
|
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
|
{
|
|
// Default to MySql for backwards compatibility
|
|
quote! { sqlx::MySql }
|
|
}
|
|
}
|
|
|
|
fn db_arguments() -> TokenStream2 {
|
|
#[cfg(feature = "postgres")]
|
|
{
|
|
quote! { sqlx::postgres::PgArguments }
|
|
}
|
|
#[cfg(feature = "sqlite")]
|
|
{
|
|
return quote! { sqlx::sqlite::SqliteArguments<'q> };
|
|
}
|
|
#[cfg(feature = "mysql")]
|
|
{
|
|
quote! { sqlx::mysql::MySqlArguments }
|
|
}
|
|
#[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))]
|
|
{
|
|
quote! { sqlx::mysql::MySqlArguments }
|
|
}
|
|
}
|
|
|
|
/// Returns the identifier quote character used around table names.
///
/// PostgreSQL and SQLite use double quotes; MySQL (and the no-feature default)
/// uses backticks. The precedence matches `db_type`. `cfg!` keeps the function
/// well-typed when more than one backend feature is enabled.
fn table_quote() -> &'static str {
    if cfg!(feature = "postgres") || cfg!(feature = "sqlite") {
        "\""
    } else {
        "`"
    }
}
|
|
|
|
/// Bind placeholder embedded in SQL that is checked at compile time.
///
/// PostgreSQL uses numbered `$index` placeholders; every other backend uses
/// the positional `?`, in which case `index` is ignored.
fn static_placeholder(index: usize) -> String {
    if cfg!(feature = "postgres") {
        format!("${}", index)
    } else {
        String::from("?")
    }
}
|
|
|
|
fn derive_entity_internal(input: TokenStream) -> TokenStream {
|
|
let input = parse_macro_input!(input as DeriveInput);
|
|
let name = &input.ident;
|
|
let update_form_name = format_ident!("{}UpdateForm", name);
|
|
|
|
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
|
|
let table_name = get_table_name(&input);
|
|
let fields = parse_fields(&input);
|
|
|
|
let primary_key = fields.iter()
|
|
.find(|f| f.is_primary_key)
|
|
.or_else(|| fields.iter().find(|f| f.ident == "id" || f.ident == "code"))
|
|
.expect("Struct must have a primary key field, either explicitly specified or named 'id' or 'code'");
|
|
|
|
// Check for timestamp fields - either by attribute or by name
|
|
let has_created_at = fields.iter().any(|f| f.is_created_at) ||
|
|
fields.iter().any(|f| f.ident == "created_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64")));
|
|
let has_updated_at = fields.iter().any(|f| f.is_updated_at) ||
|
|
fields.iter().any(|f| f.ident == "updated_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64")));
|
|
|
|
let version_field = fields.iter()
|
|
.find(|f| f.is_version_field)
|
|
.or_else(|| fields.iter().find(|&f| is_version_field(f)));
|
|
|
|
// Find soft delete field (by attribute or by name convention)
|
|
// Convention: `is_active` (FALSE = deleted), `is_deleted`/`deleted` (TRUE = deleted)
|
|
let soft_delete_field = fields.iter()
|
|
.find(|f| f.is_soft_delete)
|
|
.or_else(|| fields.iter().find(|f| {
|
|
(f.ident == "is_active" || f.ident == "is_deleted" || f.ident == "deleted") &&
|
|
matches!(&f.ty, Type::Path(p) if p.path.is_ident("bool"))
|
|
}));
|
|
|
|
// Generate all implementations
|
|
let insert_impl = generate_insert_impl(&name, &table_name, primary_key, &fields, has_created_at, has_updated_at, &impl_generics, &ty_generics, &where_clause);
|
|
let get_impl = generate_get_impl(&name, &table_name, primary_key, version_field, soft_delete_field, &fields, &impl_generics, &ty_generics, &where_clause);
|
|
let update_impl = generate_update_impl(&name, &update_form_name, &table_name, &fields, primary_key, version_field, has_updated_at, &impl_generics, &ty_generics, &where_clause);
|
|
let diff_impl = generate_diff_impl(&name, &update_form_name, &fields, primary_key, version_field, &impl_generics, &ty_generics, &where_clause);
|
|
let delete_impl = generate_delete_impl(&name, &table_name, primary_key, &impl_generics, &ty_generics, &where_clause);
|
|
let soft_delete_impl = generate_soft_delete_impl(&name, &table_name, primary_key, soft_delete_field, &impl_generics, &ty_generics, &where_clause);
|
|
|
|
let pk_type = &primary_key.ty;
|
|
let pk_field_name = &primary_key.ident;
|
|
|
|
quote! {
|
|
#insert_impl
|
|
#get_impl
|
|
#update_impl
|
|
#diff_impl
|
|
#delete_impl
|
|
#soft_delete_impl
|
|
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
pub const fn table_name() -> &'static str {
|
|
#table_name
|
|
}
|
|
|
|
pub fn entity_key(#pk_field_name: &#pk_type) -> String {
|
|
format!("/entities/{}/{}", #table_name, #pk_field_name)
|
|
}
|
|
|
|
pub fn entity_changes_table_name() -> String {
|
|
format!("entity_changes_{}", #table_name)
|
|
}
|
|
}
|
|
}.into()
|
|
}
|
|
fn get_table_name(input: &DeriveInput) -> String {
|
|
input.attrs.iter()
|
|
.find_map(|attr| {
|
|
if attr.path().is_ident("table_name") {
|
|
parse_string_attr(attr)
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.unwrap_or_else(|| to_snake_case(&input.ident.to_string()))
|
|
}
|
|
|
|
fn parse_fields(input: &DeriveInput) -> Vec<EntityField> {
|
|
match &input.data {
|
|
Data::Struct(data_struct) => {
|
|
data_struct.fields.iter().map(|field| {
|
|
let ident = field.ident.as_ref().unwrap().clone();
|
|
let db_name = field.attrs.iter()
|
|
.find_map(|attr| {
|
|
if attr.path().is_ident("rename") {
|
|
parse_string_attr(attr)
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.unwrap_or_else(|| ident.to_string());
|
|
|
|
let type_override = field.attrs.iter()
|
|
.find_map(|attr| {
|
|
if attr.path().is_ident("field_type") {
|
|
parse_string_attr(attr)
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
|
|
let needs_type_annotation = type_override.is_some() || {
|
|
matches!(&field.ty, syn::Type::Path(p) if {
|
|
let type_str = quote!(#p).to_string();
|
|
type_str.contains("Uuid") || type_str.contains("bool")
|
|
})
|
|
};
|
|
|
|
let is_primary_key = field.attrs.iter()
|
|
.any(|attr| attr.path().is_ident("primary_key"));
|
|
let is_version_field = field.attrs.iter()
|
|
.any(|attr| attr.path().is_ident("version"));
|
|
let is_soft_delete = field.attrs.iter()
|
|
.any(|attr| attr.path().is_ident("soft_delete"));
|
|
let is_created_at = field.attrs.iter()
|
|
.any(|attr| attr.path().is_ident("created_at"));
|
|
let is_updated_at = field.attrs.iter()
|
|
.any(|attr| attr.path().is_ident("updated_at"));
|
|
|
|
EntityField {
|
|
ident,
|
|
db_name,
|
|
ty: field.ty.clone(),
|
|
needs_type_annotation,
|
|
type_override,
|
|
is_primary_key,
|
|
is_version_field,
|
|
is_soft_delete,
|
|
is_created_at,
|
|
is_updated_at,
|
|
}
|
|
}).collect()
|
|
}
|
|
_ => panic!("Entity can only be derived for structs"),
|
|
}
|
|
}
|
|
|
|
fn is_version_field(f: &EntityField) -> bool {
|
|
f.ident == "version" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("u64") ||
|
|
p.path.is_ident("u32") || p.path.is_ident("i64") || p.path.is_ident("i32"))
|
|
}
|
|
|
|
// Generate the insert implementation
|
|
fn generate_insert_impl(
|
|
name: &Ident,
|
|
table_name: &str,
|
|
primary_key: &EntityField,
|
|
fields: &[EntityField],
|
|
_has_created_at: bool,
|
|
_has_updated_at: bool,
|
|
impl_generics: &ImplGenerics,
|
|
ty_generics: &TypeGenerics,
|
|
where_clause: &Option<&WhereClause>,
|
|
) -> TokenStream2 {
|
|
let db_names: Vec<_> = fields.iter().map(|f| &f.db_name).collect();
|
|
let field_idents: Vec<_> = fields.iter().map(|f| &f.ident).collect();
|
|
let tq = table_quote();
|
|
let db = db_type();
|
|
let pk_db_name = &primary_key.db_name;
|
|
|
|
let bindings: Vec<_> = fields.iter().map(|f| {
|
|
let ident = &f.ident;
|
|
quote! { &self.#ident }
|
|
}).collect();
|
|
|
|
let pk_field = &primary_key.ident;
|
|
let pk_type = &primary_key.ty;
|
|
|
|
let field_count = db_names.len();
|
|
|
|
quote! {
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
pub async fn insert<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let placeholders: String = (1..=#field_count)
|
|
.map(|i| ::sqlx_record::prelude::placeholder(i))
|
|
.collect::<Vec<_>>()
|
|
.join(", ");
|
|
|
|
let insert_stmt = format!(
|
|
"INSERT INTO {}{}{} ({}) VALUES ({})",
|
|
#tq, #table_name, #tq,
|
|
vec![#(#db_names),*].join(", "),
|
|
placeholders
|
|
);
|
|
|
|
let result = sqlx::query(&insert_stmt)
|
|
#(.bind(#bindings))*
|
|
.execute(executor)
|
|
.await?;
|
|
|
|
Ok(self.#pk_field.clone())
|
|
}
|
|
|
|
/// Insert multiple entities in a single statement
|
|
pub async fn insert_many<'a, E>(executor: E, entities: &[Self]) -> Result<Vec<#pk_type>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
Self: Clone,
|
|
{
|
|
if entities.is_empty() {
|
|
return Ok(vec![]);
|
|
}
|
|
|
|
let field_count = #field_count;
|
|
let mut placeholders = Vec::with_capacity(entities.len());
|
|
let mut current_idx = 1usize;
|
|
|
|
for _ in entities {
|
|
let row_placeholders: String = (0..field_count)
|
|
.map(|_| {
|
|
let ph = ::sqlx_record::prelude::placeholder(current_idx);
|
|
current_idx += 1;
|
|
ph
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join(", ");
|
|
placeholders.push(format!("({})", row_placeholders));
|
|
}
|
|
|
|
let insert_stmt = format!(
|
|
"INSERT INTO {}{}{} ({}) VALUES {}",
|
|
#tq, #table_name, #tq,
|
|
vec![#(#db_names),*].join(", "),
|
|
placeholders.join(", ")
|
|
);
|
|
|
|
let mut query = sqlx::query(&insert_stmt);
|
|
for entity in entities {
|
|
#(query = query.bind(&entity.#field_idents);)*
|
|
}
|
|
|
|
query.execute(executor).await?;
|
|
|
|
Ok(entities.iter().map(|e| e.#pk_field.clone()).collect())
|
|
}
|
|
|
|
/// Insert or update on primary key conflict (upsert)
|
|
pub async fn upsert<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let placeholders: String = (1..=#field_count)
|
|
.map(|i| ::sqlx_record::prelude::placeholder(i))
|
|
.collect::<Vec<_>>()
|
|
.join(", ");
|
|
|
|
let non_pk_fields: Vec<&str> = vec![#(#db_names),*]
|
|
.into_iter()
|
|
.filter(|f| *f != #pk_db_name)
|
|
.collect();
|
|
|
|
let upsert_stmt = ::sqlx_record::prelude::build_upsert_stmt(
|
|
#table_name,
|
|
&[#(#db_names),*],
|
|
#pk_db_name,
|
|
&non_pk_fields,
|
|
&placeholders,
|
|
);
|
|
|
|
sqlx::query(&upsert_stmt)
|
|
#(.bind(#bindings))*
|
|
.execute(executor)
|
|
.await?;
|
|
|
|
Ok(self.#pk_field.clone())
|
|
}
|
|
|
|
/// Alias for upsert
|
|
pub async fn insert_or_update<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
self.upsert(executor).await
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn get_type_string(field: &EntityField) -> String {
|
|
if field.needs_type_annotation {
|
|
if let Some(override_type) = &field.type_override {
|
|
override_type.clone()
|
|
} else {
|
|
let ty = &field.ty;
|
|
let type_str = quote!(#ty).to_string();
|
|
|
|
// Remove all whitespace
|
|
let clean_type = type_str.replace(" ", "");
|
|
|
|
if clean_type.starts_with("Option<") && clean_type.ends_with(">") {
|
|
// Extract inner type between < and >
|
|
clean_type[7..clean_type.len()-1].to_string()
|
|
} else {
|
|
type_str
|
|
}
|
|
}
|
|
} else {
|
|
String::new()
|
|
}
|
|
}
|
|
|
|
// Generate the get implementations
|
|
fn generate_get_impl(
|
|
name: &Ident,
|
|
table_name: &str,
|
|
primary_key: &EntityField,
|
|
version_field: Option<&EntityField>,
|
|
_soft_delete_field: Option<&EntityField>, // Reserved for future auto-filtering
|
|
fields: &[EntityField],
|
|
impl_generics: &ImplGenerics,
|
|
ty_generics: &TypeGenerics,
|
|
where_clause: &Option<&WhereClause>,
|
|
) -> TokenStream2 {
|
|
let select_fields = fields.iter().map(|f| {
|
|
if f.needs_type_annotation {
|
|
let type_str = get_type_string(f);
|
|
format!("{} as \"{}: {}\"", f.db_name, f.ident, type_str)
|
|
} else {
|
|
f.db_name.clone()
|
|
}
|
|
});
|
|
|
|
let pk_field_name = primary_key.ident.to_string();
|
|
let multi_pk_field_name = Some(pluralize(&pk_field_name))
|
|
.filter(|x| x != &pk_field_name)
|
|
.unwrap_or_else(|| format!("{}_list", pk_field_name));
|
|
|
|
let get_by_func = format_ident!("get_by_{}", pk_field_name);
|
|
let multi_get_by_func = format_ident!("get_by_{}", multi_pk_field_name);
|
|
let pk_field = &primary_key.ident;
|
|
let pk_type = &primary_key.ty;
|
|
let pk_field_name = primary_key.ident.to_string();
|
|
let pk_db_field_name = &primary_key.db_name;
|
|
let tq = table_quote();
|
|
let db = db_type();
|
|
|
|
let new_fields = select_fields.clone().collect::<Vec<_>>();
|
|
let select_fields_str = new_fields.iter()
|
|
.filter_map(|e| e.split(" ")
|
|
.next()).collect::<Vec<_>>().join(", ");
|
|
|
|
let select_field_list = select_fields.clone().collect::<Vec<_>>();
|
|
|
|
// Generate the get_version function if version_field exists
|
|
let get_version_impl = if let Some(vfield) = version_field {
|
|
let version_field_type = &vfield.ty;
|
|
let version_db_name = &vfield.db_name;
|
|
|
|
quote! {
|
|
pub async fn get_version<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<#version_field_type>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let query = format!(
|
|
r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#,
|
|
#version_db_name,
|
|
#tq, Self::table_name(), #tq,
|
|
#pk_db_field_name,
|
|
::sqlx_record::prelude::placeholder(1)
|
|
);
|
|
|
|
let result = sqlx::query_scalar(&query)
|
|
.bind(#pk_field)
|
|
.fetch_optional(executor)
|
|
.await?;
|
|
Ok(result)
|
|
}
|
|
|
|
pub async fn get_versions<'a, E>(executor: E, keys: &Vec<#pk_type>) -> Result<::std::collections::HashMap<#pk_type, #version_field_type>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
if keys.is_empty() {
|
|
return Ok(::std::collections::HashMap::new());
|
|
}
|
|
use sqlx::Row;
|
|
|
|
let placeholders: String = (1..=keys.len())
|
|
.map(|i| ::sqlx_record::prelude::placeholder(i))
|
|
.collect::<Vec<_>>()
|
|
.join(",");
|
|
let query = format!(
|
|
r#"SELECT DISTINCT {}, {} FROM {}{}{} WHERE {} IN ({})"#,
|
|
#pk_db_field_name,
|
|
#version_db_name,
|
|
#tq, Self::table_name(), #tq,
|
|
#pk_db_field_name,
|
|
placeholders
|
|
);
|
|
|
|
let mut query_builder = sqlx::query(&query);
|
|
for key in keys {
|
|
query_builder = query_builder.bind(key);
|
|
}
|
|
|
|
let rows = query_builder
|
|
.fetch_all(executor)
|
|
.await?;
|
|
|
|
let mut result = ::std::collections::HashMap::with_capacity(rows.len());
|
|
for row in rows {
|
|
let key: #pk_type = row.try_get(0)?;
|
|
let version: #version_field_type = row.try_get(1)?;
|
|
result.insert(key, version);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
|
|
} else {
|
|
// If no version field, generate empty implementation
|
|
quote! {}
|
|
};
|
|
|
|
let field_list = fields.iter().map(|f| f.db_name.clone()).collect::<Vec<_>>();
|
|
|
|
// Check if static-check feature is enabled at macro expansion time
|
|
let use_static_validation = cfg!(feature = "static-check");
|
|
|
|
let get_by_impl = if use_static_validation {
|
|
// Static validation: use sqlx::query_as! with compile-time checked SQL
|
|
let select_stmt = format!(
|
|
r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#,
|
|
select_fields.clone().collect::<Vec<_>>().join(", "),
|
|
tq, table_name, tq, pk_db_field_name,
|
|
static_placeholder(1)
|
|
);
|
|
quote! {
|
|
pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let result = sqlx::query_as!(
|
|
Self,
|
|
#select_stmt,
|
|
#pk_field
|
|
)
|
|
.fetch_optional(executor)
|
|
.await?;
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let result = sqlx::query_as!(
|
|
Self,
|
|
#select_stmt,
|
|
#pk_field
|
|
)
|
|
.fetch_optional(executor)
|
|
.await?;
|
|
|
|
Ok(result)
|
|
}
|
|
}
|
|
} else {
|
|
// Runtime: use sqlx::query_as with dynamic SQL
|
|
quote! {
|
|
pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let select_stmt = format!(
|
|
r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#,
|
|
vec![#(#field_list),*].join(","),
|
|
#tq, #table_name, #tq, #pk_db_field_name,
|
|
::sqlx_record::prelude::placeholder(1)
|
|
);
|
|
let result = sqlx::query_as::<_, Self>(&select_stmt)
|
|
.bind(#pk_field)
|
|
.fetch_optional(executor)
|
|
.await?;
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let select_stmt = format!(
|
|
r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#,
|
|
vec![#(#field_list),*].join(","),
|
|
#tq, #table_name, #tq, #pk_db_field_name,
|
|
::sqlx_record::prelude::placeholder(1)
|
|
);
|
|
let result = sqlx::query_as::<_, Self>(&select_stmt)
|
|
.bind(#pk_field)
|
|
.fetch_optional(executor)
|
|
.await?;
|
|
|
|
Ok(result)
|
|
}
|
|
}
|
|
};
|
|
|
|
quote! {
|
|
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
|
|
#get_by_impl
|
|
|
|
pub const fn primary_key_field() -> &'static str {
|
|
#pk_field_name
|
|
}
|
|
|
|
pub const fn primary_key_db_field() -> &'static str {
|
|
#pk_db_field_name
|
|
}
|
|
|
|
pub fn primary_key(&self) -> &#pk_type {
|
|
&self.#pk_field
|
|
}
|
|
|
|
pub fn select_fields() -> Vec<&'static str> {
|
|
vec![#(#select_field_list),*]
|
|
}
|
|
|
|
pub async fn #multi_get_by_func<'a, E>(executor: E, ids: &[#pk_type]) -> Result<Vec<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
if ids.is_empty() {
|
|
return Ok(vec![]);
|
|
}
|
|
|
|
let placeholders: String = (1..=ids.len())
|
|
.map(|i| ::sqlx_record::prelude::placeholder(i))
|
|
.collect::<Vec<_>>()
|
|
.join(",");
|
|
|
|
let query = format!(
|
|
r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} IN ({})"#,
|
|
#select_fields_str,
|
|
#tq, Self::table_name(), #tq,
|
|
#pk_db_field_name,
|
|
placeholders
|
|
);
|
|
|
|
let mut q = sqlx::query_as::<_, Self>(&query);
|
|
for id in ids {
|
|
q = q.bind(id);
|
|
}
|
|
|
|
q.fetch_all(executor).await
|
|
}
|
|
|
|
#get_version_impl
|
|
|
|
pub async fn find<'a, E>(
|
|
executor: E,
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
index: Option<&str>,
|
|
) -> Result<Vec<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
Self::find_ordered_with_limit(executor, filters, index, Vec::new(), None).await
|
|
}
|
|
pub async fn find_ordered<'a, E>(
|
|
executor: E,
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
index: Option<&str>,
|
|
order_by: Vec<(&str, bool)>,
|
|
) -> Result<Vec<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
Self::find_ordered_with_limit(executor, filters, index, order_by, None).await
|
|
}
|
|
|
|
pub async fn find_one<'a, E>(
|
|
executor: E,
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
index: Option<&str>,
|
|
) -> Result<Option<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
let found = Self::find_ordered_with_limit(executor, filters, index, Vec::new(), Some((0, 1))).await?;
|
|
Ok(found.into_iter().next())
|
|
}
|
|
|
|
pub async fn find_ordered_with_limit<'a, E>(
|
|
executor: E,
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
index: Option<&str>,
|
|
order_by: Vec<(&str, bool)>,
|
|
offset_limit: Option<(u32, u32)>,
|
|
) -> Result<Vec<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
use ::sqlx_record::prelude::{Filter, Value, bind_as_values};
|
|
|
|
let (where_conditions, values) = Filter::build_where_clause(&filters);
|
|
let where_clause = if !where_conditions.is_empty() {
|
|
format!("WHERE {}", where_conditions)
|
|
} else {
|
|
String::new()
|
|
};
|
|
|
|
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
|
|
|
//Filter order_by fields to only those managed
|
|
let fields = Self::select_fields().into_iter().collect::<::std::collections::HashSet<_>>();
|
|
let order_by = order_by.iter()
|
|
.filter(|(field, _)| fields.contains(field))
|
|
.collect::<Vec<_>>();
|
|
|
|
let order_by_clause = if !order_by.is_empty() {
|
|
let order_by_str = order_by.iter()
|
|
.map(|(field, asc)| format!("{} {}", field, if *asc { "ASC" } else { "DESC" }))
|
|
.collect::<Vec<_>>()
|
|
.join(", ");
|
|
format!("ORDER BY {}", order_by_str)
|
|
} else {
|
|
String::new()
|
|
};
|
|
|
|
let query = format!(
|
|
r#"SELECT DISTINCT {} FROM {}{}{} {} {} {} {}"#,
|
|
#select_fields_str,
|
|
#tq, #table_name, #tq,
|
|
index_clause,
|
|
where_clause,
|
|
order_by_clause,
|
|
offset_limit.map(|(offset, limit)| format!("LIMIT {} OFFSET {}", limit, offset)).unwrap_or_default(),
|
|
);
|
|
|
|
let db_query = sqlx::query_as(&query);
|
|
|
|
// Bind values to the query
|
|
let results = match bind_as_values(db_query, &values)
|
|
.fetch_all(executor)
|
|
.await {
|
|
Ok(results) => results,
|
|
Err(err) => {
|
|
tracing::error!(r#"Error executing
|
|
Query: {}
|
|
Error => {:?}
|
|
"#, query, err);
|
|
return Err(err);
|
|
}
|
|
};
|
|
|
|
Ok(results)
|
|
}
|
|
|
|
pub async fn count<'a, E>(
|
|
executor: E,
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
index: Option<&str>,
|
|
) -> Result<u64, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
use ::sqlx_record::prelude::{Filter, Value, bind_scalar_values};
|
|
|
|
let (where_conditions, values) = Filter::build_where_clause(&filters);
|
|
let where_clause = if !where_conditions.is_empty() {
|
|
format!("WHERE {}", where_conditions)
|
|
} else {
|
|
String::new()
|
|
};
|
|
|
|
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
|
let count_expr = ::sqlx_record::prelude::build_count_expr(#pk_db_field_name);
|
|
|
|
let query = format!(
|
|
r#"SELECT {} FROM {}{}{} {} {}"#,
|
|
count_expr,
|
|
#tq, #table_name, #tq,
|
|
index_clause,
|
|
where_clause,
|
|
);
|
|
|
|
let db_query = sqlx::query_scalar::<_, i64>(&query);
|
|
|
|
// Bind values to the query
|
|
let count = match bind_scalar_values(db_query, &values)
|
|
.fetch_optional(executor)
|
|
.await {
|
|
Ok(count) => count.unwrap_or(0) as u64,
|
|
Err(err) => {
|
|
tracing::error!(r#"Error executing
|
|
Query: {}
|
|
Error => {:?}
|
|
"#, query, err);
|
|
return Err(err);
|
|
}
|
|
};
|
|
|
|
Ok(count)
|
|
}
|
|
|
|
/// Paginate results with total count
|
|
pub async fn paginate<'a, E>(
|
|
executor: E,
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
index: Option<&str>,
|
|
order_by: Vec<(&str, bool)>,
|
|
page_request: ::sqlx_record::prelude::PageRequest,
|
|
) -> Result<::sqlx_record::prelude::Page<Self>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db> + Copy,
|
|
{
|
|
// Get total count first
|
|
let total_count = Self::count(executor, filters.clone(), index).await?;
|
|
|
|
// Get page items
|
|
let items = Self::find_ordered_with_limit(
|
|
executor,
|
|
filters,
|
|
index,
|
|
order_by,
|
|
Some((page_request.offset(), page_request.limit())),
|
|
).await?;
|
|
|
|
Ok(::sqlx_record::prelude::Page::new(
|
|
items,
|
|
total_count,
|
|
page_request.page,
|
|
page_request.page_size,
|
|
))
|
|
}
|
|
|
|
/// Select specific fields only (returns raw rows)
|
|
/// Use `sqlx::Row` trait to access fields: `row.try_get::<String, _>("name")?`
|
|
pub async fn find_partial<'a, E>(
|
|
executor: E,
|
|
select_fields: &[&str],
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
index: Option<&str>,
|
|
) -> Result<Vec<<#db as sqlx::Database>::Row>, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
use ::sqlx_record::prelude::{Filter, bind_values};
|
|
|
|
// Validate fields exist
|
|
let valid_fields: ::std::collections::HashSet<_> = Self::select_fields().into_iter().collect();
|
|
let selected: Vec<_> = select_fields.iter()
|
|
.filter(|f| valid_fields.contains(*f))
|
|
.copied()
|
|
.collect();
|
|
|
|
if selected.is_empty() {
|
|
return Ok(vec![]);
|
|
}
|
|
|
|
let (where_conditions, values) = Filter::build_where_clause(&filters);
|
|
let where_clause = if !where_conditions.is_empty() {
|
|
format!("WHERE {}", where_conditions)
|
|
} else {
|
|
String::new()
|
|
};
|
|
|
|
let index_clause = ::sqlx_record::prelude::build_index_clause(index);
|
|
|
|
let query = format!(
|
|
"SELECT DISTINCT {} FROM {}{}{} {} {}",
|
|
selected.join(", "),
|
|
#tq, #table_name, #tq,
|
|
index_clause,
|
|
where_clause,
|
|
);
|
|
|
|
let db_query = sqlx::query(&query);
|
|
bind_values(db_query, &values)
|
|
.fetch_all(executor)
|
|
.await
|
|
}
|
|
}
|
|
}
|
|
}
|
|
/// Check if a type is a binary type (Vec<u8>)
|
|
fn is_binary_type(ty: &Type) -> bool {
|
|
if let Type::Path(type_path) = ty {
|
|
let path_str = quote!(#type_path).to_string().replace(" ", "");
|
|
path_str == "Vec<u8>" || path_str == "std::vec::Vec<u8>"
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
fn generate_update_impl(
|
|
name: &Ident,
|
|
update_form_name: &Ident,
|
|
table_name: &str,
|
|
fields: &[EntityField],
|
|
primary_key: &EntityField,
|
|
version_field: Option<&EntityField>,
|
|
has_updated_at: bool,
|
|
impl_generics: &ImplGenerics,
|
|
ty_generics: &TypeGenerics,
|
|
where_clause: &Option<&WhereClause>,
|
|
) -> TokenStream2 {
|
|
let update_fields: Vec<_> = fields.iter()
|
|
.filter(|f| f.ident != primary_key.ident && f.ident != "created_at" &&
|
|
version_field.as_ref().map(|vf| f.ident != vf.ident).unwrap_or(true))
|
|
.collect();
|
|
|
|
let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect();
|
|
let field_types: Vec<_> = update_fields.iter().map(|f| &f.ty).collect();
|
|
let db_names: Vec<_> = update_fields.iter().map(|f| &f.db_name).collect();
|
|
let db = db_type();
|
|
let db_args = db_arguments();
|
|
|
|
let setter_methods: Vec<_> = update_fields.iter().map(|field| {
|
|
let method_name = format_ident!("set_{}", field.ident);
|
|
let field_type = &field.ty;
|
|
let field_ident = &field.ident;
|
|
quote! {
|
|
pub fn #method_name<T>(&mut self, value: T) -> ()
|
|
where
|
|
T: Into<#field_type>,
|
|
{
|
|
self.#field_ident = Some(value.into());
|
|
}
|
|
}
|
|
}).collect();
|
|
|
|
let builder_methods = update_fields.iter().map(|field| {
|
|
let method_name = format_ident!("with_{}", field.ident);
|
|
let field_type = &field.ty;
|
|
let field_ident = &field.ident;
|
|
quote! {
|
|
pub fn #method_name<T>(mut self, value: T) -> Self
|
|
where
|
|
T: Into<#field_type>,
|
|
{
|
|
self.#field_ident = Some(value.into());
|
|
self
|
|
}
|
|
}
|
|
});
|
|
|
|
// Generate eval_* methods for non-binary fields
|
|
let eval_methods: Vec<_> = update_fields.iter()
|
|
.filter(|f| !is_binary_type(&f.ty))
|
|
.map(|field| {
|
|
let method_name = format_ident!("eval_{}", field.ident);
|
|
let db_name = &field.db_name;
|
|
quote! {
|
|
/// Set field using an UpdateExpr for complex operations (arithmetic, CASE/WHEN, etc.)
|
|
/// Takes precedence over with_* if both are set.
|
|
pub fn #method_name(mut self, expr: ::sqlx_record::prelude::UpdateExpr) -> Self {
|
|
self._exprs.insert(#db_name, expr);
|
|
self
|
|
}
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
// Version increment - use CASE WHEN for cross-database compatibility
|
|
let version_increment = if let Some(vfield) = version_field {
|
|
let version_db_name = &vfield.db_name;
|
|
let version_type = &vfield.ty;
|
|
|
|
// Handle different integer types with appropriate max values for wrapping
|
|
let max_val = match version_type {
|
|
Type::Path(type_path) if type_path.path.is_ident("u32") => "4294967295",
|
|
Type::Path(type_path) if type_path.path.is_ident("u64") => "18446744073709551615",
|
|
Type::Path(type_path) if type_path.path.is_ident("i32") => "2147483647",
|
|
Type::Path(type_path) if type_path.path.is_ident("i64") => "9223372036854775807",
|
|
_ => "",
|
|
};
|
|
|
|
if max_val.is_empty() {
|
|
quote! {
|
|
parts.push(format!("{} = {} + 1", #version_db_name, #version_db_name));
|
|
}
|
|
} else {
|
|
quote! {
|
|
parts.push(format!("{} = CASE WHEN {} = {} THEN 0 ELSE {} + 1 END", #version_db_name, #version_db_name, #max_val, #version_db_name));
|
|
}
|
|
}
|
|
} else {
|
|
quote! {}
|
|
};
|
|
|
|
// Auto-update updated_at timestamp (only if not manually set)
|
|
let updated_at_increment = if has_updated_at {
|
|
quote! {
|
|
// Only auto-set updated_at if not already set in form or via expression
|
|
if self.updated_at.is_none() && !self._exprs.contains_key("updated_at") {
|
|
parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx)));
|
|
values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis()));
|
|
idx += 1;
|
|
}
|
|
}
|
|
} else {
|
|
quote! {}
|
|
};
|
|
|
|
quote! {
|
|
/// Update form with support for simple value updates and complex expressions
|
|
pub struct #update_form_name #ty_generics #where_clause {
|
|
#(pub #field_idents: Option<#field_types>,)*
|
|
/// Expression-based updates (eval_* methods). Takes precedence over with_* for same field.
|
|
pub _exprs: std::collections::HashMap<&'static str, ::sqlx_record::prelude::UpdateExpr>,
|
|
}
|
|
|
|
impl #impl_generics Default for #update_form_name #ty_generics #where_clause {
|
|
fn default() -> Self {
|
|
Self {
|
|
#(#field_idents: None,)*
|
|
_exprs: std::collections::HashMap::new(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
pub fn update_form() -> #update_form_name {
|
|
#update_form_name::default()
|
|
}
|
|
}
|
|
|
|
impl #impl_generics #update_form_name #ty_generics #where_clause {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
#(#setter_methods)*
|
|
|
|
#(#builder_methods)*
|
|
|
|
#(#eval_methods)*
|
|
|
|
/// Raw SQL expression escape hatch for any field (no bind parameters).
|
|
/// For expressions with parameters, use `raw_with_values()`.
|
|
pub fn raw(mut self, field: &'static str, sql: impl Into<String>) -> Self {
|
|
self._exprs.insert(field, ::sqlx_record::prelude::UpdateExpr::Raw {
|
|
sql: sql.into(),
|
|
values: vec![],
|
|
});
|
|
self
|
|
}
|
|
|
|
/// Raw SQL expression escape hatch with bind parameters.
|
|
/// Use `?` for placeholders - they will be converted to proper database placeholders.
|
|
pub fn raw_with_values(mut self, field: &'static str, sql: impl Into<String>, values: Vec<::sqlx_record::prelude::Value>) -> Self {
|
|
self._exprs.insert(field, ::sqlx_record::prelude::UpdateExpr::Raw {
|
|
sql: sql.into(),
|
|
values,
|
|
});
|
|
self
|
|
}
|
|
|
|
/// Generate UPDATE SET clause and collect values for binding.
|
|
/// Expression fields (eval_*) take precedence over simple value fields (with_*).
|
|
pub fn update_stmt_with_values(&self) -> (String, Vec<::sqlx_record::prelude::Value>) {
|
|
let mut parts = Vec::new();
|
|
let mut values = Vec::new();
|
|
let mut idx = 1usize;
|
|
|
|
#(
|
|
// Check if this field has an expression (takes precedence)
|
|
if let Some(expr) = self._exprs.get(#db_names) {
|
|
let (sql, expr_values) = expr.build_sql(#db_names, idx);
|
|
parts.push(format!("{} = {}", #db_names, sql));
|
|
idx += expr_values.len();
|
|
values.extend(expr_values);
|
|
} else if let Some(ref value) = self.#field_idents {
|
|
parts.push(format!("{} = {}", #db_names, ::sqlx_record::prelude::placeholder(idx)));
|
|
values.push(::sqlx_record::prelude::Value::from(value.clone()));
|
|
idx += 1;
|
|
}
|
|
)*
|
|
|
|
#version_increment
|
|
#updated_at_increment
|
|
|
|
(parts.join(", "), values)
|
|
}
|
|
|
|
/// Generate UPDATE SET clause (backward compatible, without values).
|
|
/// For new code, prefer `update_stmt_with_values()`.
|
|
pub fn update_stmt(&self) -> String {
|
|
self.update_stmt_with_values().0
|
|
}
|
|
|
|
/// Bind all form values to query in correct order.
|
|
/// Handles both simple values and expression values, respecting expression precedence.
|
|
/// Uses Value enum for proper type handling of Option<T> fields.
|
|
pub fn bind_all_values<'q>(&'q self, mut query: sqlx::query::Query<'q, #db, #db_args>)
|
|
-> sqlx::query::Query<'q, #db, #db_args>
|
|
{
|
|
// Use update_stmt_with_values to get properly converted values
|
|
// This handles nested Options (Option<Option<T>>) correctly
|
|
let (_, values) = self.update_stmt_with_values();
|
|
for value in values {
|
|
query = ::sqlx_record::prelude::bind_value_owned(query, value);
|
|
}
|
|
query
|
|
}
|
|
|
|
/// Legacy binding method - binds values through the Value enum for proper type handling.
|
|
/// For backward compatibility. New code should use bind_all_values().
|
|
pub fn bind_form_values<'q>(&'q self, mut query: sqlx::query::Query<'q, #db, #db_args>)
|
|
-> sqlx::query::Query<'q, #db, #db_args>
|
|
{
|
|
// Always use Value-based binding to properly handle Option<T> fields
|
|
// This ensures nested Options (Option<Option<T>>) are unwrapped correctly
|
|
let (_, values) = self.update_stmt_with_values();
|
|
for value in values {
|
|
query = ::sqlx_record::prelude::bind_value_owned(query, value);
|
|
}
|
|
query
|
|
}
|
|
|
|
/// Check if this form uses any expressions
|
|
pub fn has_expressions(&self) -> bool {
|
|
!self._exprs.is_empty()
|
|
}
|
|
|
|
/// Get the number of bind parameters this form will use.
|
|
pub fn param_count(&self) -> usize {
|
|
self.update_stmt_with_values().1.len()
|
|
}
|
|
|
|
pub const fn table_name() -> &'static str {
|
|
#table_name
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn generate_diff_impl(
|
|
name: &Ident,
|
|
update_form_name: &Ident,
|
|
fields: &[EntityField],
|
|
primary_key: &EntityField,
|
|
version_field: Option<&EntityField>,
|
|
impl_generics: &ImplGenerics,
|
|
ty_generics: &TypeGenerics,
|
|
where_clause: &Option<&WhereClause>,
|
|
) -> TokenStream2 {
|
|
let update_fields: Vec<_> = fields.iter()
|
|
.filter(|f| f.ident != primary_key.ident && f.ident != "created_at" &&
|
|
version_field.as_ref().map(|vf| f.ident != vf.ident).unwrap_or(true))
|
|
.collect();
|
|
|
|
let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect();
|
|
let db_names: Vec<_> = update_fields.iter().map(|f| &f.db_name).collect();
|
|
let field_types: Vec<_> = update_fields.iter().map(|f| &f.ty).collect();
|
|
let bind_form_values_trait = format_ident!("{}BindFormValues", name);
|
|
let bind_form_values_func = format_ident!("{}", to_snake_case(&name.to_string()));
|
|
|
|
let pk_field = primary_key.ident.clone();
|
|
let pk_type = primary_key.ty.clone();
|
|
let pk_db_name = primary_key.db_name.clone();
|
|
|
|
let multi_pk_field= format_ident!("{}", pluralize(&pk_field.to_string()));
|
|
let update_by_func = format_ident!("update_by_{}", pk_field);
|
|
let multi_update_by_func = format_ident!("update_by_{}", multi_pk_field);
|
|
|
|
let db = db_type();
|
|
let db_args = db_arguments();
|
|
let tq = table_quote();
|
|
|
|
quote! {
|
|
impl #impl_generics #update_form_name #ty_generics #where_clause {
|
|
pub fn model_diff(&self, other: &#name #ty_generics) -> serde_json::Value {
|
|
let mut changes = serde_json::Map::new();
|
|
#(
|
|
if let Some(ref value) = &self.#field_idents {
|
|
if &other.#field_idents != value {
|
|
changes.insert(#db_names.to_string(), serde_json::json!(value));
|
|
}
|
|
}
|
|
)*
|
|
serde_json::Value::Object(changes)
|
|
}
|
|
|
|
pub fn diff_modify(&mut self, model: &#name #ty_generics) -> serde_json::Value {
|
|
let mut changes = serde_json::Map::new();
|
|
#(
|
|
if let Some(ref value) = self.#field_idents {
|
|
if &model.#field_idents == value {
|
|
self.#field_idents = None;
|
|
} else {
|
|
changes.insert(#db_names.to_string(), serde_json::json!(value));
|
|
}
|
|
}
|
|
)*
|
|
serde_json::Value::Object(changes)
|
|
}
|
|
|
|
pub async fn db_diff<'q, E>(&self, #pk_field: &#pk_type, executor: E) -> Result<serde_json::Value, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'q, Database=#db>,
|
|
{
|
|
use sqlx::Row;
|
|
let fields_to_fetch: Vec<String> = vec![
|
|
#(
|
|
if self.#field_idents.is_some() {
|
|
#db_names.to_string()
|
|
} else {
|
|
String::new()
|
|
}
|
|
),*
|
|
].into_iter().filter(|s| !s.is_empty()).collect();
|
|
|
|
if fields_to_fetch.is_empty() {
|
|
return Ok(serde_json::json!({}));
|
|
}
|
|
|
|
let query = format!(
|
|
"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}",
|
|
fields_to_fetch.join(", "),
|
|
#tq, Self::table_name(), #tq, #pk_db_name,
|
|
::sqlx_record::prelude::placeholder(1)
|
|
);
|
|
|
|
let row = sqlx::query(&query)
|
|
.bind(#pk_field)
|
|
.fetch_one(executor)
|
|
.await?;
|
|
|
|
let mut changes = serde_json::Map::new();
|
|
#(
|
|
if let Some(value) = &self.#field_idents {
|
|
if fields_to_fetch.contains(&#db_names.to_string()) {
|
|
let db_value: #field_types = row.try_get(#db_names)?;
|
|
if *value != db_value {
|
|
changes.insert(#db_names.to_string(), serde_json::json!(value));
|
|
}
|
|
}
|
|
}
|
|
)*
|
|
Ok(serde_json::Value::Object(changes))
|
|
}
|
|
}
|
|
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
pub fn to_update_form(&self) -> #update_form_name #ty_generics {
|
|
#update_form_name {
|
|
#(#field_idents: Some(self.#field_idents.clone()),)*
|
|
_exprs: std::collections::HashMap::new(),
|
|
}
|
|
}
|
|
|
|
pub fn initial_diff(&self) -> serde_json::Value {
|
|
let mut map = serde_json::Map::new();
|
|
#(
|
|
map.insert(#db_names.to_string(), serde_json::to_value(&self.#field_idents).unwrap_or(serde_json::Value::Null));
|
|
)*
|
|
serde_json::Value::Object(map)
|
|
}
|
|
}
|
|
|
|
|
|
pub trait #bind_form_values_trait<'q, 'f>
|
|
where 'f: 'q {
|
|
fn #bind_form_values_func(self, form: &'f #update_form_name) -> Self;
|
|
}
|
|
|
|
impl<'q, 'f> #bind_form_values_trait<'q, 'f> for sqlx::query::Query<'q, #db, #db_args>
|
|
where 'f: 'q {
|
|
fn #bind_form_values_func(self, form: &'f #update_form_name) -> Self {
|
|
form.bind_form_values(self)
|
|
}
|
|
}
|
|
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
pub async fn update<'a, E>(&self, executor: E, form: #update_form_name) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
// Count parameters in the update statement
|
|
let update_stmt = form.update_stmt();
|
|
let param_count = update_stmt.matches(::sqlx_record::prelude::placeholder(1).chars().next().unwrap_or('?')).count();
|
|
|
|
let query_str = format!(
|
|
r#"UPDATE {}{}{} SET {} WHERE {} = {}"#,
|
|
#tq, Self::table_name(), #tq,
|
|
update_stmt,
|
|
#pk_db_name,
|
|
::sqlx_record::prelude::placeholder(param_count + 1)
|
|
);
|
|
|
|
let _q = match sqlx::query(&query_str)
|
|
.#bind_form_values_func(&form)
|
|
.bind(&self.#pk_field)
|
|
.execute(executor)
|
|
.await {
|
|
Ok(q) => q,
|
|
Err(err) => {
|
|
tracing::error!(r#"Error updating entity: {:?}
|
|
Query String: {}
|
|
"#, err, query_str);
|
|
return Err(err);
|
|
}
|
|
};
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn #update_by_func<'a, E>(executor: E, #pk_field: &#pk_type, form: #update_form_name) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
// Count parameters in the update statement
|
|
let update_stmt = form.update_stmt();
|
|
let param_count = update_stmt.matches(::sqlx_record::prelude::placeholder(1).chars().next().unwrap_or('?')).count();
|
|
|
|
let query_str = format!(
|
|
r#"UPDATE {}{}{} SET {} WHERE {} = {}"#,
|
|
#tq, Self::table_name(), #tq,
|
|
update_stmt,
|
|
#pk_db_name,
|
|
::sqlx_record::prelude::placeholder(param_count + 1)
|
|
);
|
|
|
|
let _q = match sqlx::query(&query_str)
|
|
.#bind_form_values_func(&form)
|
|
.bind(#pk_field)
|
|
.execute(executor)
|
|
.await {
|
|
Ok(q) => q,
|
|
Err(err) => {
|
|
tracing::error!(r#"Error updating entity: {:?}
|
|
Query String: {}
|
|
"#, err, query_str);
|
|
return Err(err);
|
|
}
|
|
};
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn #multi_update_by_func<'a, E>(executor: E, #multi_pk_field: &Vec<#pk_type>, form: #update_form_name) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
if #multi_pk_field.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
let update_stmt = form.update_stmt();
|
|
let form_param_count = update_stmt.matches(::sqlx_record::prelude::placeholder(1).chars().next().unwrap_or('?')).count();
|
|
let placeholders: String = (1..=#multi_pk_field.len())
|
|
.map(|i| ::sqlx_record::prelude::placeholder(form_param_count + i))
|
|
.collect::<Vec<_>>()
|
|
.join(",");
|
|
|
|
let query_str = format!(
|
|
r#"UPDATE {}{}{} SET {} WHERE {} IN ({})"#,
|
|
#tq, Self::table_name(), #tq,
|
|
update_stmt,
|
|
#pk_db_name, placeholders,
|
|
);
|
|
|
|
let mut q = sqlx::query(&query_str)
|
|
.#bind_form_values_func(&form);
|
|
|
|
for #pk_field in #multi_pk_field {
|
|
q = q.bind(#pk_field);
|
|
}
|
|
|
|
match q.execute(executor)
|
|
.await {
|
|
Ok(_q) => (),
|
|
Err(err) => {
|
|
tracing::error!(r#"Error updating entity: {:?}
|
|
Query String: {}
|
|
"#, err, query_str);
|
|
return Err(err);
|
|
}
|
|
};
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Update all records matching the filter conditions
|
|
/// Returns the number of affected rows
|
|
pub async fn update_by_filter<'a, E>(
|
|
executor: E,
|
|
filters: Vec<::sqlx_record::prelude::Filter<'a>>,
|
|
form: #update_form_name,
|
|
) -> Result<u64, sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database=#db>,
|
|
{
|
|
use ::sqlx_record::prelude::{Filter, bind_values};
|
|
|
|
if filters.is_empty() {
|
|
// Require at least one filter to prevent accidental table-wide updates
|
|
return Err(sqlx::Error::Protocol(
|
|
"update_by_filter requires at least one filter to prevent accidental table-wide updates".to_string()
|
|
));
|
|
}
|
|
|
|
let (update_stmt, form_values) = form.update_stmt_with_values();
|
|
if update_stmt.is_empty() {
|
|
return Ok(0);
|
|
}
|
|
|
|
let form_param_count = form_values.len();
|
|
let (where_conditions, filter_values) = Filter::build_where_clause_with_offset(&filters, form_param_count + 1);
|
|
|
|
let query_str = format!(
|
|
r#"UPDATE {}{}{} SET {} WHERE {}"#,
|
|
#tq, Self::table_name(), #tq,
|
|
update_stmt,
|
|
where_conditions,
|
|
);
|
|
|
|
// Combine form values and filter values
|
|
let mut all_values = form_values;
|
|
all_values.extend(filter_values);
|
|
|
|
let query = sqlx::query(&query_str);
|
|
let result = bind_values(query, &all_values)
|
|
.execute(executor)
|
|
.await?;
|
|
|
|
Ok(result.rows_affected())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Generate delete implementation - always generated for ALL entities
|
|
fn generate_delete_impl(
|
|
name: &Ident,
|
|
table_name: &str,
|
|
primary_key: &EntityField,
|
|
impl_generics: &ImplGenerics,
|
|
ty_generics: &TypeGenerics,
|
|
where_clause: &Option<&WhereClause>,
|
|
) -> TokenStream2 {
|
|
let pk_field = &primary_key.ident;
|
|
let pk_type = &primary_key.ty;
|
|
let pk_db_name = &primary_key.db_name;
|
|
let db = db_type();
|
|
let tq = table_quote();
|
|
|
|
let pk_field_name = primary_key.ident.to_string();
|
|
let hard_delete_by_func = format_ident!("hard_delete_by_{}", pk_field_name);
|
|
|
|
quote! {
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
/// Hard delete - permanently removes the row from database
|
|
pub async fn hard_delete<'a, E>(&self, executor: E) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database = #db>,
|
|
{
|
|
Self::#hard_delete_by_func(executor, &self.#pk_field).await
|
|
}
|
|
|
|
/// Hard delete by primary key - permanently removes the row from database
|
|
pub async fn #hard_delete_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database = #db>,
|
|
{
|
|
let query = format!(
|
|
"DELETE FROM {}{}{} WHERE {} = {}",
|
|
#tq, #table_name, #tq,
|
|
#pk_db_name, ::sqlx_record::prelude::placeholder(1)
|
|
);
|
|
sqlx::query(&query).bind(#pk_field).execute(executor).await?;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Generate soft delete implementation
|
|
fn generate_soft_delete_impl(
|
|
name: &Ident,
|
|
table_name: &str,
|
|
primary_key: &EntityField,
|
|
soft_delete_field: Option<&EntityField>,
|
|
impl_generics: &ImplGenerics,
|
|
ty_generics: &TypeGenerics,
|
|
where_clause: &Option<&WhereClause>,
|
|
) -> TokenStream2 {
|
|
let Some(sd_field) = soft_delete_field else {
|
|
return quote! {};
|
|
};
|
|
|
|
let pk_field = &primary_key.ident;
|
|
let pk_type = &primary_key.ty;
|
|
let pk_db_name = &primary_key.db_name;
|
|
let sd_db_name = &sd_field.db_name;
|
|
let db = db_type();
|
|
let tq = table_quote();
|
|
|
|
let pk_field_name = primary_key.ident.to_string();
|
|
let soft_delete_by_func = format_ident!("soft_delete_by_{}", pk_field_name);
|
|
let restore_by_func = format_ident!("restore_by_{}", pk_field_name);
|
|
|
|
// Determine semantics based on field name and attribute:
|
|
// - #[soft_delete] attribute: field should be FALSE when deleted (user convention)
|
|
// - `is_active` by name: FALSE when deleted, TRUE when active
|
|
// - `is_deleted`/`deleted` by name: TRUE when deleted, FALSE when active
|
|
let sd_field_name = sd_field.ident.to_string();
|
|
let is_inverted = sd_field.is_soft_delete || sd_field_name == "is_active";
|
|
|
|
let (delete_value, restore_value) = if is_inverted {
|
|
("FALSE", "TRUE")
|
|
} else {
|
|
("TRUE", "FALSE")
|
|
};
|
|
|
|
quote! {
|
|
impl #impl_generics #name #ty_generics #where_clause {
|
|
/// Soft delete - marks record as deleted without removing from database
|
|
pub async fn soft_delete<'a, E>(&self, executor: E) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database = #db>,
|
|
{
|
|
Self::#soft_delete_by_func(executor, &self.#pk_field).await
|
|
}
|
|
|
|
/// Soft delete by primary key
|
|
pub async fn #soft_delete_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database = #db>,
|
|
{
|
|
let query = format!(
|
|
"UPDATE {}{}{} SET {} = {} WHERE {} = {}",
|
|
#tq, #table_name, #tq,
|
|
#sd_db_name, #delete_value,
|
|
#pk_db_name, ::sqlx_record::prelude::placeholder(1)
|
|
);
|
|
sqlx::query(&query).bind(#pk_field).execute(executor).await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Restore a soft-deleted record
|
|
pub async fn restore<'a, E>(&self, executor: E) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database = #db>,
|
|
{
|
|
Self::#restore_by_func(executor, &self.#pk_field).await
|
|
}
|
|
|
|
/// Restore by primary key
|
|
pub async fn #restore_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error>
|
|
where
|
|
E: sqlx::Executor<'a, Database = #db>,
|
|
{
|
|
let query = format!(
|
|
"UPDATE {}{}{} SET {} = {} WHERE {} = {}",
|
|
#tq, #table_name, #tq,
|
|
#sd_db_name, #restore_value,
|
|
#pk_db_name, ::sqlx_record::prelude::placeholder(1)
|
|
);
|
|
sqlx::query(&query).bind(#pk_field).execute(executor).await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Get the soft delete field name
|
|
pub const fn soft_delete_field() -> &'static str {
|
|
#sd_db_name
|
|
}
|
|
}
|
|
}
|
|
}
|