mod string_utils; extern crate proc_macro; use proc_macro::TokenStream; use proc_macro2::{Ident, TokenStream as TokenStream2}; use quote::{quote, format_ident}; use syn::{parse_macro_input, DeriveInput, Data, LitStr, Type, ImplGenerics, TypeGenerics, WhereClause}; use crate::string_utils::{pluralize, to_snake_case}; struct EntityField { ident: Ident, db_name: String, ty: Type, needs_type_annotation: bool, type_override: Option, is_primary_key: bool, is_version_field: bool, is_soft_delete: bool, is_created_at: bool, is_updated_at: bool, } /// Parse a string attribute that can be either: /// - `#[attr("value")]` (Meta::List style) /// - `#[attr = "value"]` (Meta::NameValue style) fn parse_string_attr(attr: &syn::Attribute) -> Option { match &attr.meta { syn::Meta::List(_) => { // #[attr("value")] style attr.parse_args::().ok().map(|lit| lit.value()) } syn::Meta::NameValue(nv) => { // #[attr = "value"] style if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. }) = &nv.value { Some(lit.value()) } else { None } } _ => None, } } // Support both Update and Entity attributes #[proc_macro_derive(Update, attributes(rename, table_name, primary_key, field_type))] pub fn derive_update(input: TokenStream) -> TokenStream { derive_entity_internal(input) } #[proc_macro_derive(Entity, attributes(rename, table_name, primary_key, version, field_type, soft_delete, created_at, updated_at))] pub fn derive_entity(input: TokenStream) -> TokenStream { derive_entity_internal(input) } /// Generate database-specific types based on features fn db_type() -> TokenStream2 { #[cfg(feature = "postgres")] { quote! { sqlx::Postgres } } #[cfg(feature = "sqlite")] { return quote! { sqlx::Sqlite }; } #[cfg(feature = "mysql")] { quote! { sqlx::MySql } } #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] { // Default to MySql for backwards compatibility quote! { sqlx::MySql } } } fn db_arguments() -> TokenStream2 { #[cfg(feature = "postgres")] { quote! { sqlx::postgres::PgArguments } } #[cfg(feature = "sqlite")] { return quote! { sqlx::sqlite::SqliteArguments<'q> }; } #[cfg(feature = "mysql")] { quote! { sqlx::mysql::MySqlArguments } } #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] { quote! { sqlx::mysql::MySqlArguments } } } /// Get table quote character fn table_quote() -> &'static str { #[cfg(feature = "postgres")] { "\"" } #[cfg(feature = "sqlite")] { return "\""; } #[cfg(feature = "mysql")] { "`" } #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] { "`" } } /// Get compile-time placeholder for static-check SQL fn static_placeholder(index: usize) -> String { #[cfg(feature = "postgres")] { format!("${}", index) } #[cfg(not(feature = "postgres"))] { let _ = index; "?".to_string() } } fn derive_entity_internal(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let name = &input.ident; let update_form_name = format_ident!("{}UpdateForm", name); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let table_name = get_table_name(&input); let fields = parse_fields(&input); let primary_key = fields.iter() .find(|f| f.is_primary_key) .or_else(|| fields.iter().find(|f| f.ident == "id" || f.ident == "code")) .expect("Struct must have a primary key field, either explicitly specified or named 'id' or 'code'"); // Check for timestamp fields - either by attribute or by name let has_created_at = fields.iter().any(|f| f.is_created_at) || fields.iter().any(|f| f.ident == "created_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64"))); let has_updated_at = fields.iter().any(|f| f.is_updated_at) || fields.iter().any(|f| f.ident == "updated_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64"))); let version_field = fields.iter() .find(|f| f.is_version_field) .or_else(|| fields.iter().find(|&f| is_version_field(f))); // Find soft delete field (by attribute or by name convention) // Convention: `is_active` (FALSE = deleted), `is_deleted`/`deleted` (TRUE = deleted) let soft_delete_field = fields.iter() .find(|f| f.is_soft_delete) .or_else(|| fields.iter().find(|f| { (f.ident == "is_active" || f.ident == "is_deleted" || f.ident == "deleted") && matches!(&f.ty, Type::Path(p) if p.path.is_ident("bool")) })); // Generate all implementations let insert_impl = generate_insert_impl(&name, &table_name, primary_key, &fields, has_created_at, has_updated_at, &impl_generics, &ty_generics, &where_clause); let get_impl = generate_get_impl(&name, &table_name, primary_key, version_field, soft_delete_field, &fields, &impl_generics, &ty_generics, &where_clause); let update_impl = generate_update_impl(&name, &update_form_name, &table_name, &fields, primary_key, version_field, has_updated_at, &impl_generics, &ty_generics, &where_clause); let diff_impl = generate_diff_impl(&name, &update_form_name, &fields, primary_key, version_field, &impl_generics, &ty_generics, &where_clause); let delete_impl = generate_delete_impl(&name, &table_name, primary_key, &impl_generics, &ty_generics, &where_clause); let soft_delete_impl = generate_soft_delete_impl(&name, &table_name, primary_key, soft_delete_field, &impl_generics, &ty_generics, &where_clause); let pk_type = &primary_key.ty; let pk_field_name = &primary_key.ident; quote! { #insert_impl #get_impl #update_impl #diff_impl #delete_impl #soft_delete_impl impl #impl_generics #name #ty_generics #where_clause { pub const fn table_name() -> &'static str { #table_name } pub fn entity_key(#pk_field_name: &#pk_type) -> String { format!("/entities/{}/{}", #table_name, #pk_field_name) } pub fn entity_changes_table_name() -> String { format!("entity_changes_{}", #table_name) } } }.into() } fn get_table_name(input: &DeriveInput) -> String { input.attrs.iter() .find_map(|attr| { if attr.path().is_ident("table_name") { parse_string_attr(attr) } else { None } }) .unwrap_or_else(|| to_snake_case(&input.ident.to_string())) } fn parse_fields(input: &DeriveInput) -> Vec { match &input.data { Data::Struct(data_struct) => { data_struct.fields.iter().map(|field| { let ident = field.ident.as_ref().unwrap().clone(); let db_name = field.attrs.iter() .find_map(|attr| { if attr.path().is_ident("rename") { parse_string_attr(attr) } else { None } }) .unwrap_or_else(|| ident.to_string()); let type_override = field.attrs.iter() .find_map(|attr| { if attr.path().is_ident("field_type") { parse_string_attr(attr) } else { None } }); let needs_type_annotation = type_override.is_some() || { matches!(&field.ty, syn::Type::Path(p) if { let type_str = quote!(#p).to_string(); type_str.contains("Uuid") || type_str.contains("bool") }) }; let is_primary_key = field.attrs.iter() .any(|attr| attr.path().is_ident("primary_key")); let is_version_field = field.attrs.iter() .any(|attr| attr.path().is_ident("version")); let is_soft_delete = field.attrs.iter() .any(|attr| attr.path().is_ident("soft_delete")); let is_created_at = field.attrs.iter() .any(|attr| attr.path().is_ident("created_at")); let is_updated_at = field.attrs.iter() .any(|attr| attr.path().is_ident("updated_at")); EntityField { ident, db_name, ty: field.ty.clone(), needs_type_annotation, type_override, is_primary_key, is_version_field, is_soft_delete, is_created_at, is_updated_at, } }).collect() } _ => panic!("Entity can only be derived for structs"), } } fn is_version_field(f: &EntityField) -> bool { f.ident == "version" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("u32") || p.path.is_ident("i64") || p.path.is_ident("i32")) } // Generate the insert implementation fn generate_insert_impl( name: &Ident, table_name: &str, primary_key: &EntityField, fields: &[EntityField], _has_created_at: bool, _has_updated_at: bool, impl_generics: &ImplGenerics, ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { let db_names: Vec<_> = fields.iter().map(|f| &f.db_name).collect(); let field_idents: Vec<_> = fields.iter().map(|f| &f.ident).collect(); let tq = table_quote(); let db = db_type(); let pk_db_name = &primary_key.db_name; let bindings: Vec<_> = fields.iter().map(|f| { let ident = &f.ident; quote! { &self.#ident } }).collect(); let pk_field = &primary_key.ident; let pk_type = &primary_key.ty; let field_count = db_names.len(); quote! { impl #impl_generics #name #ty_generics #where_clause { pub async fn insert<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let placeholders: String = (1..=#field_count) .map(|i| ::sqlx_record::prelude::placeholder(i)) .collect::>() .join(", "); let insert_stmt = format!( "INSERT INTO {}{}{} ({}) VALUES ({})", #tq, #table_name, #tq, vec![#(#db_names),*].join(", "), placeholders ); let result = sqlx::query(&insert_stmt) #(.bind(#bindings))* .execute(executor) .await?; Ok(self.#pk_field.clone()) } /// Insert multiple entities in a single statement pub async fn insert_many<'a, E>(executor: E, entities: &[Self]) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, Self: Clone, { if entities.is_empty() { return Ok(vec![]); } let field_count = #field_count; let mut placeholders = Vec::with_capacity(entities.len()); let mut current_idx = 1usize; for _ in entities { let row_placeholders: String = (0..field_count) .map(|_| { let ph = ::sqlx_record::prelude::placeholder(current_idx); current_idx += 1; ph }) .collect::>() .join(", "); placeholders.push(format!("({})", row_placeholders)); } let insert_stmt = format!( "INSERT INTO {}{}{} ({}) VALUES {}", #tq, #table_name, #tq, vec![#(#db_names),*].join(", "), placeholders.join(", ") ); let mut query = sqlx::query(&insert_stmt); for entity in entities { #(query = query.bind(&entity.#field_idents);)* } query.execute(executor).await?; Ok(entities.iter().map(|e| e.#pk_field.clone()).collect()) } /// Insert or update on primary key conflict (upsert) pub async fn upsert<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let placeholders: String = (1..=#field_count) .map(|i| ::sqlx_record::prelude::placeholder(i)) .collect::>() .join(", "); let non_pk_fields: Vec<&str> = vec![#(#db_names),*] .into_iter() .filter(|f| *f != #pk_db_name) .collect(); let upsert_stmt = ::sqlx_record::prelude::build_upsert_stmt( #table_name, &[#(#db_names),*], #pk_db_name, &non_pk_fields, &placeholders, ); sqlx::query(&upsert_stmt) #(.bind(#bindings))* .execute(executor) .await?; Ok(self.#pk_field.clone()) } /// Alias for upsert pub async fn insert_or_update<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { self.upsert(executor).await } } } } fn get_type_string(field: &EntityField) -> String { if field.needs_type_annotation { if let Some(override_type) = &field.type_override { override_type.clone() } else { let ty = &field.ty; let type_str = quote!(#ty).to_string(); // Remove all whitespace let clean_type = type_str.replace(" ", ""); if clean_type.starts_with("Option<") && clean_type.ends_with(">") { // Extract inner type between < and > clean_type[7..clean_type.len()-1].to_string() } else { type_str } } } else { String::new() } } // Generate the get implementations fn generate_get_impl( name: &Ident, table_name: &str, primary_key: &EntityField, version_field: Option<&EntityField>, _soft_delete_field: Option<&EntityField>, // Reserved for future auto-filtering fields: &[EntityField], impl_generics: &ImplGenerics, ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { let select_fields = fields.iter().map(|f| { if f.needs_type_annotation { let type_str = get_type_string(f); format!("{} as \"{}: {}\"", f.db_name, f.ident, type_str) } else { f.db_name.clone() } }); let pk_field_name = primary_key.ident.to_string(); let multi_pk_field_name = Some(pluralize(&pk_field_name)) .filter(|x| x != &pk_field_name) .unwrap_or_else(|| format!("{}_list", pk_field_name)); let get_by_func = format_ident!("get_by_{}", pk_field_name); let multi_get_by_func = format_ident!("get_by_{}", multi_pk_field_name); let pk_field = &primary_key.ident; let pk_type = &primary_key.ty; let pk_field_name = primary_key.ident.to_string(); let pk_db_field_name = &primary_key.db_name; let tq = table_quote(); let db = db_type(); let new_fields = select_fields.clone().collect::>(); let select_fields_str = new_fields.iter() .filter_map(|e| e.split(" ") .next()).collect::>().join(", "); let select_field_list = select_fields.clone().collect::>(); // Generate the get_version function if version_field exists let get_version_impl = if let Some(vfield) = version_field { let version_field_type = &vfield.ty; let version_db_name = &vfield.db_name; quote! { pub async fn get_version<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let query = format!( r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, #version_db_name, #tq, Self::table_name(), #tq, #pk_db_field_name, ::sqlx_record::prelude::placeholder(1) ); let result = sqlx::query_scalar(&query) .bind(#pk_field) .fetch_optional(executor) .await?; Ok(result) } pub async fn get_versions<'a, E>(executor: E, keys: &Vec<#pk_type>) -> Result<::std::collections::HashMap<#pk_type, #version_field_type>, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { if keys.is_empty() { return Ok(::std::collections::HashMap::new()); } use sqlx::Row; let placeholders: String = (1..=keys.len()) .map(|i| ::sqlx_record::prelude::placeholder(i)) .collect::>() .join(","); let query = format!( r#"SELECT DISTINCT {}, {} FROM {}{}{} WHERE {} IN ({})"#, #pk_db_field_name, #version_db_name, #tq, Self::table_name(), #tq, #pk_db_field_name, placeholders ); let mut query_builder = sqlx::query(&query); for key in keys { query_builder = query_builder.bind(key); } let rows = query_builder .fetch_all(executor) .await?; let mut result = ::std::collections::HashMap::with_capacity(rows.len()); for row in rows { let key: #pk_type = row.try_get(0)?; let version: #version_field_type = row.try_get(1)?; result.insert(key, version); } Ok(result) } } } else { // If no version field, generate empty implementation quote! {} }; let field_list = fields.iter().map(|f| f.db_name.clone()).collect::>(); // Check if static-check feature is enabled at macro expansion time let use_static_validation = cfg!(feature = "static-check"); let get_by_impl = if use_static_validation { // Static validation: use sqlx::query_as! with compile-time checked SQL let select_stmt = format!( r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, select_fields.clone().collect::>().join(", "), tq, table_name, tq, pk_db_field_name, static_placeholder(1) ); quote! { pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let result = sqlx::query_as!( Self, #select_stmt, #pk_field ) .fetch_optional(executor) .await?; Ok(result) } pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let result = sqlx::query_as!( Self, #select_stmt, #pk_field ) .fetch_optional(executor) .await?; Ok(result) } } } else { // Runtime: use sqlx::query_as with dynamic SQL quote! { pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let select_stmt = format!( r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, vec![#(#field_list),*].join(","), #tq, #table_name, #tq, #pk_db_field_name, ::sqlx_record::prelude::placeholder(1) ); let result = sqlx::query_as::<_, Self>(&select_stmt) .bind(#pk_field) .fetch_optional(executor) .await?; Ok(result) } pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let select_stmt = format!( r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, vec![#(#field_list),*].join(","), #tq, #table_name, #tq, #pk_db_field_name, ::sqlx_record::prelude::placeholder(1) ); let result = sqlx::query_as::<_, Self>(&select_stmt) .bind(#pk_field) .fetch_optional(executor) .await?; Ok(result) } } }; quote! { impl #impl_generics #name #ty_generics #where_clause { #get_by_impl pub const fn primary_key_field() -> &'static str { #pk_field_name } pub const fn primary_key_db_field() -> &'static str { #pk_db_field_name } pub fn primary_key(&self) -> &#pk_type { &self.#pk_field } pub fn select_fields() -> Vec<&'static str> { vec![#(#select_field_list),*] } pub async fn #multi_get_by_func<'a, E>(executor: E, ids: &[#pk_type]) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { if ids.is_empty() { return Ok(vec![]); } let placeholders: String = (1..=ids.len()) .map(|i| ::sqlx_record::prelude::placeholder(i)) .collect::>() .join(","); let query = format!( r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} IN ({})"#, #select_fields_str, #tq, Self::table_name(), #tq, #pk_db_field_name, placeholders ); let mut q = sqlx::query_as::<_, Self>(&query); for id in ids { q = q.bind(id); } q.fetch_all(executor).await } #get_version_impl pub async fn find<'a, E>( executor: E, filters: Vec<::sqlx_record::prelude::Filter<'a>>, index: Option<&str>, ) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { Self::find_ordered_with_limit(executor, filters, index, Vec::new(), None).await } pub async fn find_ordered<'a, E>( executor: E, filters: Vec<::sqlx_record::prelude::Filter<'a>>, index: Option<&str>, order_by: Vec<(&str, bool)>, ) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { Self::find_ordered_with_limit(executor, filters, index, order_by, None).await } pub async fn find_one<'a, E>( executor: E, filters: Vec<::sqlx_record::prelude::Filter<'a>>, index: Option<&str>, ) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { let found = Self::find_ordered_with_limit(executor, filters, index, Vec::new(), Some((0, 1))).await?; Ok(found.into_iter().next()) } pub async fn find_ordered_with_limit<'a, E>( executor: E, filters: Vec<::sqlx_record::prelude::Filter<'a>>, index: Option<&str>, order_by: Vec<(&str, bool)>, offset_limit: Option<(u32, u32)>, ) -> Result, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { use ::sqlx_record::prelude::{Filter, Value, bind_as_values}; let (where_conditions, values) = Filter::build_where_clause(&filters); let where_clause = if !where_conditions.is_empty() { format!("WHERE {}", where_conditions) } else { String::new() }; let index_clause = ::sqlx_record::prelude::build_index_clause(index); //Filter order_by fields to only those managed let fields = Self::select_fields().into_iter().collect::<::std::collections::HashSet<_>>(); let order_by = order_by.iter() .filter(|(field, _)| fields.contains(field)) .collect::>(); let order_by_clause = if !order_by.is_empty() { let order_by_str = order_by.iter() .map(|(field, asc)| format!("{} {}", field, if *asc { "ASC" } else { "DESC" })) .collect::>() .join(", "); format!("ORDER BY {}", order_by_str) } else { String::new() }; let query = format!( r#"SELECT DISTINCT {} FROM {}{}{} {} {} {} {}"#, #select_fields_str, #tq, #table_name, #tq, index_clause, where_clause, order_by_clause, offset_limit.map(|(offset, limit)| format!("LIMIT {} OFFSET {}", limit, offset)).unwrap_or_default(), ); let db_query = sqlx::query_as(&query); // Bind values to the query let results = match bind_as_values(db_query, &values) .fetch_all(executor) .await { Ok(results) => results, Err(err) => { tracing::error!(r#"Error executing Query: {} Error => {:?} "#, query, err); return Err(err); } }; Ok(results) } pub async fn count<'a, E>( executor: E, filters: Vec<::sqlx_record::prelude::Filter<'a>>, index: Option<&str>, ) -> Result where E: sqlx::Executor<'a, Database=#db>, { use ::sqlx_record::prelude::{Filter, Value, bind_scalar_values}; let (where_conditions, values) = Filter::build_where_clause(&filters); let where_clause = if !where_conditions.is_empty() { format!("WHERE {}", where_conditions) } else { String::new() }; let index_clause = ::sqlx_record::prelude::build_index_clause(index); let count_expr = ::sqlx_record::prelude::build_count_expr(#pk_db_field_name); let query = format!( r#"SELECT {} FROM {}{}{} {} {}"#, count_expr, #tq, #table_name, #tq, index_clause, where_clause, ); let db_query = sqlx::query_scalar::<_, i64>(&query); // Bind values to the query let count = match bind_scalar_values(db_query, &values) .fetch_optional(executor) .await { Ok(count) => count.unwrap_or(0) as u64, Err(err) => { tracing::error!(r#"Error executing Query: {} Error => {:?} "#, query, err); return Err(err); } }; Ok(count) } /// Paginate results with total count pub async fn paginate<'a, E>( executor: E, filters: Vec<::sqlx_record::prelude::Filter<'a>>, index: Option<&str>, order_by: Vec<(&str, bool)>, page_request: ::sqlx_record::prelude::PageRequest, ) -> Result<::sqlx_record::prelude::Page, sqlx::Error> where E: sqlx::Executor<'a, Database=#db> + Copy, { // Get total count first let total_count = Self::count(executor, filters.clone(), index).await?; // Get page items let items = Self::find_ordered_with_limit( executor, filters, index, order_by, Some((page_request.offset(), page_request.limit())), ).await?; Ok(::sqlx_record::prelude::Page::new( items, total_count, page_request.page, page_request.page_size, )) } /// Select specific fields only (returns raw rows) /// Use `sqlx::Row` trait to access fields: `row.try_get::("name")?` pub async fn find_partial<'a, E>( executor: E, select_fields: &[&str], filters: Vec<::sqlx_record::prelude::Filter<'a>>, index: Option<&str>, ) -> Result::Row>, sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { use ::sqlx_record::prelude::{Filter, bind_values}; // Validate fields exist let valid_fields: ::std::collections::HashSet<_> = Self::select_fields().into_iter().collect(); let selected: Vec<_> = select_fields.iter() .filter(|f| valid_fields.contains(*f)) .copied() .collect(); if selected.is_empty() { return Ok(vec![]); } let (where_conditions, values) = Filter::build_where_clause(&filters); let where_clause = if !where_conditions.is_empty() { format!("WHERE {}", where_conditions) } else { String::new() }; let index_clause = ::sqlx_record::prelude::build_index_clause(index); let query = format!( "SELECT DISTINCT {} FROM {}{}{} {} {}", selected.join(", "), #tq, #table_name, #tq, index_clause, where_clause, ); let db_query = sqlx::query(&query); bind_values(db_query, &values) .fetch_all(executor) .await } } } } /// Check if a type is a binary type (Vec) fn is_binary_type(ty: &Type) -> bool { if let Type::Path(type_path) = ty { let path_str = quote!(#type_path).to_string().replace(" ", ""); path_str == "Vec" || path_str == "std::vec::Vec" } else { false } } fn generate_update_impl( name: &Ident, update_form_name: &Ident, table_name: &str, fields: &[EntityField], primary_key: &EntityField, version_field: Option<&EntityField>, has_updated_at: bool, impl_generics: &ImplGenerics, ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { let update_fields: Vec<_> = fields.iter() .filter(|f| f.ident != primary_key.ident && f.ident != "created_at" && version_field.as_ref().map(|vf| f.ident != vf.ident).unwrap_or(true)) .collect(); let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect(); let field_types: Vec<_> = update_fields.iter().map(|f| &f.ty).collect(); let db_names: Vec<_> = update_fields.iter().map(|f| &f.db_name).collect(); let db = db_type(); let db_args = db_arguments(); let setter_methods: Vec<_> = update_fields.iter().map(|field| { let method_name = format_ident!("set_{}", field.ident); let field_type = &field.ty; let field_ident = &field.ident; quote! { pub fn #method_name(&mut self, value: T) -> () where T: Into<#field_type>, { self.#field_ident = Some(value.into()); } } }).collect(); let builder_methods = update_fields.iter().map(|field| { let method_name = format_ident!("with_{}", field.ident); let field_type = &field.ty; let field_ident = &field.ident; quote! { pub fn #method_name(mut self, value: T) -> Self where T: Into<#field_type>, { self.#field_ident = Some(value.into()); self } } }); // Generate eval_* methods for non-binary fields let eval_methods: Vec<_> = update_fields.iter() .filter(|f| !is_binary_type(&f.ty)) .map(|field| { let method_name = format_ident!("eval_{}", field.ident); let db_name = &field.db_name; quote! { /// Set field using an UpdateExpr for complex operations (arithmetic, CASE/WHEN, etc.) /// Takes precedence over with_* if both are set. pub fn #method_name(mut self, expr: ::sqlx_record::prelude::UpdateExpr) -> Self { self._exprs.insert(#db_name, expr); self } } }) .collect(); // Version increment - use CASE WHEN for cross-database compatibility let version_increment = if let Some(vfield) = version_field { let version_db_name = &vfield.db_name; let version_type = &vfield.ty; // Handle different integer types with appropriate max values for wrapping let max_val = match version_type { Type::Path(type_path) if type_path.path.is_ident("u32") => "4294967295", Type::Path(type_path) if type_path.path.is_ident("u64") => "18446744073709551615", Type::Path(type_path) if type_path.path.is_ident("i32") => "2147483647", Type::Path(type_path) if type_path.path.is_ident("i64") => "9223372036854775807", _ => "", }; if max_val.is_empty() { quote! { parts.push(format!("{} = {} + 1", #version_db_name, #version_db_name)); } } else { quote! { parts.push(format!("{} = CASE WHEN {} = {} THEN 0 ELSE {} + 1 END", #version_db_name, #version_db_name, #max_val, #version_db_name)); } } } else { quote! {} }; // Auto-update updated_at timestamp (only if not manually set) let updated_at_increment = if has_updated_at { quote! { // Only auto-set updated_at if not already set in form or via expression if self.updated_at.is_none() && !self._exprs.contains_key("updated_at") { parts.push(format!("updated_at = {}", ::sqlx_record::prelude::placeholder(idx))); values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis())); idx += 1; } } } else { quote! {} }; quote! { /// Update form with support for simple value updates and complex expressions pub struct #update_form_name #ty_generics #where_clause { #(pub #field_idents: Option<#field_types>,)* /// Expression-based updates (eval_* methods). Takes precedence over with_* for same field. pub _exprs: std::collections::HashMap<&'static str, ::sqlx_record::prelude::UpdateExpr>, } impl #impl_generics Default for #update_form_name #ty_generics #where_clause { fn default() -> Self { Self { #(#field_idents: None,)* _exprs: std::collections::HashMap::new(), } } } impl #impl_generics #name #ty_generics #where_clause { pub fn update_form() -> #update_form_name { #update_form_name::default() } } impl #impl_generics #update_form_name #ty_generics #where_clause { pub fn new() -> Self { Self::default() } #(#setter_methods)* #(#builder_methods)* #(#eval_methods)* /// Raw SQL expression escape hatch for any field (no bind parameters). /// For expressions with parameters, use `raw_with_values()`. pub fn raw(mut self, field: &'static str, sql: impl Into) -> Self { self._exprs.insert(field, ::sqlx_record::prelude::UpdateExpr::Raw { sql: sql.into(), values: vec![], }); self } /// Raw SQL expression escape hatch with bind parameters. /// Use `?` for placeholders - they will be converted to proper database placeholders. pub fn raw_with_values(mut self, field: &'static str, sql: impl Into, values: Vec<::sqlx_record::prelude::Value>) -> Self { self._exprs.insert(field, ::sqlx_record::prelude::UpdateExpr::Raw { sql: sql.into(), values, }); self } /// Generate UPDATE SET clause and collect values for binding. /// Expression fields (eval_*) take precedence over simple value fields (with_*). pub fn update_stmt_with_values(&self) -> (String, Vec<::sqlx_record::prelude::Value>) { let mut parts = Vec::new(); let mut values = Vec::new(); let mut idx = 1usize; #( // Check if this field has an expression (takes precedence) if let Some(expr) = self._exprs.get(#db_names) { let (sql, expr_values) = expr.build_sql(#db_names, idx); parts.push(format!("{} = {}", #db_names, sql)); idx += expr_values.len(); values.extend(expr_values); } else if let Some(ref value) = self.#field_idents { parts.push(format!("{} = {}", #db_names, ::sqlx_record::prelude::placeholder(idx))); values.push(::sqlx_record::prelude::Value::from(value.clone())); idx += 1; } )* #version_increment #updated_at_increment (parts.join(", "), values) } /// Generate UPDATE SET clause (backward compatible, without values). /// For new code, prefer `update_stmt_with_values()`. pub fn update_stmt(&self) -> String { self.update_stmt_with_values().0 } /// Bind all form values to query in correct order. /// Handles both simple values and expression values, respecting expression precedence. /// Uses Value enum for proper type handling of Option fields. pub fn bind_all_values<'q>(&'q self, mut query: sqlx::query::Query<'q, #db, #db_args>) -> sqlx::query::Query<'q, #db, #db_args> { // Use update_stmt_with_values to get properly converted values // This handles nested Options (Option>) correctly let (_, values) = self.update_stmt_with_values(); for value in values { query = ::sqlx_record::prelude::bind_value_owned(query, value); } query } /// Legacy binding method - binds values through the Value enum for proper type handling. /// For backward compatibility. New code should use bind_all_values(). pub fn bind_form_values<'q>(&'q self, mut query: sqlx::query::Query<'q, #db, #db_args>) -> sqlx::query::Query<'q, #db, #db_args> { // Always use Value-based binding to properly handle Option fields // This ensures nested Options (Option>) are unwrapped correctly let (_, values) = self.update_stmt_with_values(); for value in values { query = ::sqlx_record::prelude::bind_value_owned(query, value); } query } /// Check if this form uses any expressions pub fn has_expressions(&self) -> bool { !self._exprs.is_empty() } /// Get the number of bind parameters this form will use. pub fn param_count(&self) -> usize { self.update_stmt_with_values().1.len() } pub const fn table_name() -> &'static str { #table_name } } } } fn generate_diff_impl( name: &Ident, update_form_name: &Ident, fields: &[EntityField], primary_key: &EntityField, version_field: Option<&EntityField>, impl_generics: &ImplGenerics, ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { let update_fields: Vec<_> = fields.iter() .filter(|f| f.ident != primary_key.ident && f.ident != "created_at" && version_field.as_ref().map(|vf| f.ident != vf.ident).unwrap_or(true)) .collect(); let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect(); let db_names: Vec<_> = update_fields.iter().map(|f| &f.db_name).collect(); let field_types: Vec<_> = update_fields.iter().map(|f| &f.ty).collect(); let bind_form_values_trait = format_ident!("{}BindFormValues", name); let bind_form_values_func = format_ident!("{}", to_snake_case(&name.to_string())); let pk_field = primary_key.ident.clone(); let pk_type = primary_key.ty.clone(); let pk_db_name = primary_key.db_name.clone(); let multi_pk_field= format_ident!("{}", pluralize(&pk_field.to_string())); let update_by_func = format_ident!("update_by_{}", pk_field); let multi_update_by_func = format_ident!("update_by_{}", multi_pk_field); let db = db_type(); let db_args = db_arguments(); let tq = table_quote(); quote! { impl #impl_generics #update_form_name #ty_generics #where_clause { pub fn model_diff(&self, other: &#name #ty_generics) -> serde_json::Value { let mut changes = serde_json::Map::new(); #( if let Some(ref value) = &self.#field_idents { if &other.#field_idents != value { changes.insert(#db_names.to_string(), serde_json::json!(value)); } } )* serde_json::Value::Object(changes) } pub fn diff_modify(&mut self, model: &#name #ty_generics) -> serde_json::Value { let mut changes = serde_json::Map::new(); #( if let Some(ref value) = self.#field_idents { if &model.#field_idents == value { self.#field_idents = None; } else { changes.insert(#db_names.to_string(), serde_json::json!(value)); } } )* serde_json::Value::Object(changes) } pub async fn db_diff<'q, E>(&self, #pk_field: &#pk_type, executor: E) -> Result where E: sqlx::Executor<'q, Database=#db>, { use sqlx::Row; let fields_to_fetch: Vec = vec![ #( if self.#field_idents.is_some() { #db_names.to_string() } else { String::new() } ),* ].into_iter().filter(|s| !s.is_empty()).collect(); if fields_to_fetch.is_empty() { return Ok(serde_json::json!({})); } let query = format!( "SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}", fields_to_fetch.join(", "), #tq, Self::table_name(), #tq, #pk_db_name, ::sqlx_record::prelude::placeholder(1) ); let row = sqlx::query(&query) .bind(#pk_field) .fetch_one(executor) .await?; let mut changes = serde_json::Map::new(); #( if let Some(value) = &self.#field_idents { if fields_to_fetch.contains(&#db_names.to_string()) { let db_value: #field_types = row.try_get(#db_names)?; if *value != db_value { changes.insert(#db_names.to_string(), serde_json::json!(value)); } } } )* Ok(serde_json::Value::Object(changes)) } } impl #impl_generics #name #ty_generics #where_clause { pub fn to_update_form(&self) -> #update_form_name #ty_generics { #update_form_name { #(#field_idents: Some(self.#field_idents.clone()),)* _exprs: std::collections::HashMap::new(), } } pub fn initial_diff(&self) -> serde_json::Value { let mut map = serde_json::Map::new(); #( map.insert(#db_names.to_string(), serde_json::to_value(&self.#field_idents).unwrap_or(serde_json::Value::Null)); )* serde_json::Value::Object(map) } } pub trait #bind_form_values_trait<'q, 'f> where 'f: 'q { fn #bind_form_values_func(self, form: &'f #update_form_name) -> Self; } impl<'q, 'f> #bind_form_values_trait<'q, 'f> for sqlx::query::Query<'q, #db, #db_args> where 'f: 'q { fn #bind_form_values_func(self, form: &'f #update_form_name) -> Self { form.bind_form_values(self) } } impl #impl_generics #name #ty_generics #where_clause { pub async fn update<'a, E>(&self, executor: E, form: #update_form_name) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { // Count parameters in the update statement let update_stmt = form.update_stmt(); let param_count = update_stmt.matches(::sqlx_record::prelude::placeholder(1).chars().next().unwrap_or('?')).count(); let query_str = format!( r#"UPDATE {}{}{} SET {} WHERE {} = {}"#, #tq, Self::table_name(), #tq, update_stmt, #pk_db_name, ::sqlx_record::prelude::placeholder(param_count + 1) ); let _q = match sqlx::query(&query_str) .#bind_form_values_func(&form) .bind(&self.#pk_field) .execute(executor) .await { Ok(q) => q, Err(err) => { tracing::error!(r#"Error updating entity: {:?} Query String: {} "#, err, query_str); return Err(err); } }; Ok(()) } pub async fn #update_by_func<'a, E>(executor: E, #pk_field: &#pk_type, form: #update_form_name) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { // Count parameters in the update statement let update_stmt = form.update_stmt(); let param_count = update_stmt.matches(::sqlx_record::prelude::placeholder(1).chars().next().unwrap_or('?')).count(); let query_str = format!( r#"UPDATE {}{}{} SET {} WHERE {} = {}"#, #tq, Self::table_name(), #tq, update_stmt, #pk_db_name, ::sqlx_record::prelude::placeholder(param_count + 1) ); let _q = match sqlx::query(&query_str) .#bind_form_values_func(&form) .bind(#pk_field) .execute(executor) .await { Ok(q) => q, Err(err) => { tracing::error!(r#"Error updating entity: {:?} Query String: {} "#, err, query_str); return Err(err); } }; Ok(()) } pub async fn #multi_update_by_func<'a, E>(executor: E, #multi_pk_field: &Vec<#pk_type>, form: #update_form_name) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database=#db>, { if #multi_pk_field.is_empty() { return Ok(()); } let update_stmt = form.update_stmt(); let form_param_count = update_stmt.matches(::sqlx_record::prelude::placeholder(1).chars().next().unwrap_or('?')).count(); let placeholders: String = (1..=#multi_pk_field.len()) .map(|i| ::sqlx_record::prelude::placeholder(form_param_count + i)) .collect::>() .join(","); let query_str = format!( r#"UPDATE {}{}{} SET {} WHERE {} IN ({})"#, #tq, Self::table_name(), #tq, update_stmt, #pk_db_name, placeholders, ); let mut q = sqlx::query(&query_str) .#bind_form_values_func(&form); for #pk_field in #multi_pk_field { q = q.bind(#pk_field); } match q.execute(executor) .await { Ok(_q) => (), Err(err) => { tracing::error!(r#"Error updating entity: {:?} Query String: {} "#, err, query_str); return Err(err); } }; Ok(()) } /// Update all records matching the filter conditions /// Returns the number of affected rows pub async fn update_by_filter<'a, E>( executor: E, filters: Vec<::sqlx_record::prelude::Filter<'a>>, form: #update_form_name, ) -> Result where E: sqlx::Executor<'a, Database=#db>, { use ::sqlx_record::prelude::{Filter, bind_values}; if filters.is_empty() { // Require at least one filter to prevent accidental table-wide updates return Err(sqlx::Error::Protocol( "update_by_filter requires at least one filter to prevent accidental table-wide updates".to_string() )); } let (update_stmt, form_values) = form.update_stmt_with_values(); if update_stmt.is_empty() { return Ok(0); } let form_param_count = form_values.len(); let (where_conditions, filter_values) = Filter::build_where_clause_with_offset(&filters, form_param_count + 1); let query_str = format!( r#"UPDATE {}{}{} SET {} WHERE {}"#, #tq, Self::table_name(), #tq, update_stmt, where_conditions, ); // Combine form values and filter values let mut all_values = form_values; all_values.extend(filter_values); let query = sqlx::query(&query_str); let result = bind_values(query, &all_values) .execute(executor) .await?; Ok(result.rows_affected()) } } } } // Generate delete implementation - always generated for ALL entities fn generate_delete_impl( name: &Ident, table_name: &str, primary_key: &EntityField, impl_generics: &ImplGenerics, ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { let pk_field = &primary_key.ident; let pk_type = &primary_key.ty; let pk_db_name = &primary_key.db_name; let db = db_type(); let tq = table_quote(); let pk_field_name = primary_key.ident.to_string(); let hard_delete_by_func = format_ident!("hard_delete_by_{}", pk_field_name); quote! { impl #impl_generics #name #ty_generics #where_clause { /// Hard delete - permanently removes the row from database pub async fn hard_delete<'a, E>(&self, executor: E) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database = #db>, { Self::#hard_delete_by_func(executor, &self.#pk_field).await } /// Hard delete by primary key - permanently removes the row from database pub async fn #hard_delete_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database = #db>, { let query = format!( "DELETE FROM {}{}{} WHERE {} = {}", #tq, #table_name, #tq, #pk_db_name, ::sqlx_record::prelude::placeholder(1) ); sqlx::query(&query).bind(#pk_field).execute(executor).await?; Ok(()) } } } } // Generate soft delete implementation fn generate_soft_delete_impl( name: &Ident, table_name: &str, primary_key: &EntityField, soft_delete_field: Option<&EntityField>, impl_generics: &ImplGenerics, ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { let Some(sd_field) = soft_delete_field else { return quote! {}; }; let pk_field = &primary_key.ident; let pk_type = &primary_key.ty; let pk_db_name = &primary_key.db_name; let sd_db_name = &sd_field.db_name; let db = db_type(); let tq = table_quote(); let pk_field_name = primary_key.ident.to_string(); let soft_delete_by_func = format_ident!("soft_delete_by_{}", pk_field_name); let restore_by_func = format_ident!("restore_by_{}", pk_field_name); // Determine semantics based on field name and attribute: // - #[soft_delete] attribute: field should be FALSE when deleted (user convention) // - `is_active` by name: FALSE when deleted, TRUE when active // - `is_deleted`/`deleted` by name: TRUE when deleted, FALSE when active let sd_field_name = sd_field.ident.to_string(); let is_inverted = sd_field.is_soft_delete || sd_field_name == "is_active"; let (delete_value, restore_value) = if is_inverted { ("FALSE", "TRUE") } else { ("TRUE", "FALSE") }; quote! { impl #impl_generics #name #ty_generics #where_clause { /// Soft delete - marks record as deleted without removing from database pub async fn soft_delete<'a, E>(&self, executor: E) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database = #db>, { Self::#soft_delete_by_func(executor, &self.#pk_field).await } /// Soft delete by primary key pub async fn #soft_delete_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database = #db>, { let query = format!( "UPDATE {}{}{} SET {} = {} WHERE {} = {}", #tq, #table_name, #tq, #sd_db_name, #delete_value, #pk_db_name, ::sqlx_record::prelude::placeholder(1) ); sqlx::query(&query).bind(#pk_field).execute(executor).await?; Ok(()) } /// Restore a soft-deleted record pub async fn restore<'a, E>(&self, executor: E) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database = #db>, { Self::#restore_by_func(executor, &self.#pk_field).await } /// Restore by primary key pub async fn #restore_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error> where E: sqlx::Executor<'a, Database = #db>, { let query = format!( "UPDATE {}{}{} SET {} = {} WHERE {} = {}", #tq, #table_name, #tq, #sd_db_name, #restore_value, #pk_db_name, ::sqlx_record::prelude::placeholder(1) ); sqlx::query(&query).bind(#pk_field).execute(executor).await?; Ok(()) } /// Get the soft delete field name pub const fn soft_delete_field() -> &'static str { #sd_db_name } } } }