sqlx-record/sqlx-record-derive/src/lib.rs

1590 lines
59 KiB
Rust

mod string_utils;
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{quote, format_ident};
use syn::{parse_macro_input, DeriveInput, Data, LitStr, Type, ImplGenerics, TypeGenerics, WhereClause};
use crate::string_utils::{pluralize, to_snake_case};
/// Metadata for one named field of the struct being derived, as collected by `parse_fields`.
struct EntityField {
    /// Rust identifier of the struct field.
    ident: Ident,
    /// SQL column name: the `#[rename]` value when present, otherwise the field identifier.
    db_name: String,
    /// Declared Rust type of the field.
    ty: Type,
    /// Whether this column's SELECT expression carries a sqlx type override
    /// (`col as "ident: Type"`).
    needs_type_annotation: bool,
    /// Explicit type text from `#[field_type(...)]`, used verbatim as the override.
    type_override: Option<String>,
    /// Field carries `#[primary_key]`.
    is_primary_key: bool,
    /// Field carries `#[version]`.
    is_version_field: bool,
    /// Field carries `#[soft_delete]`.
    is_soft_delete: bool,
    /// Field carries `#[created_at]`.
    is_created_at: bool,
    /// Field carries `#[updated_at]`.
    is_updated_at: bool,
}
/// Extract the string value of a helper attribute.
///
/// Both spellings are accepted:
/// - `#[attr("value")]` (list style)
/// - `#[attr = "value"]` (name-value style)
///
/// Returns `None` for a bare `#[attr]`, or when the value is not a single string literal.
fn parse_string_attr(attr: &syn::Attribute) -> Option<String> {
    if let syn::Meta::NameValue(name_value) = &attr.meta {
        // `#[attr = "value"]`: only a plain string literal is meaningful here.
        return match &name_value.value {
            syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(text), .. }) => Some(text.value()),
            _ => None,
        };
    }
    if let syn::Meta::List(_) = &attr.meta {
        // `#[attr("value")]`: the parenthesised tokens must parse as one string literal.
        return attr.parse_args::<LitStr>().map(|text| text.value()).ok();
    }
    None
}
/// `#[derive(Update)]`: produces exactly the same expansion as `#[derive(Entity)]`.
///
/// Because the expansion is shared, this derive registers the same helper attributes
/// as `Entity`. Markers such as `#[version]` or `#[updated_at]` are therefore accepted
/// here too, instead of being rejected as unknown attributes.
#[proc_macro_derive(Update, attributes(rename, table_name, primary_key, version, field_type, soft_delete, created_at, updated_at))]
pub fn derive_update(input: TokenStream) -> TokenStream {
    derive_entity_internal(input)
}

/// `#[derive(Entity)]`: generates the sqlx-record API for a named-field struct.
///
/// The output covers insert/get/update/diff/delete/soft-delete helpers, plus
/// `table_name`, `entity_key` and `entity_changes_table_name`.
#[proc_macro_derive(Entity, attributes(rename, table_name, primary_key, version, field_type, soft_delete, created_at, updated_at))]
pub fn derive_entity(input: TokenStream) -> TokenStream {
    derive_entity_internal(input)
}
/// Tokens naming the sqlx `Database` type selected by the backend features.
///
/// Precedence is explicit: `postgres`, then `sqlite`, then MySQL. MySQL is also the
/// default when no backend feature is enabled, for backwards compatibility. This
/// matches `static_placeholder`, which also gives `postgres` priority.
///
/// Previously each backend lived in its own `#[cfg]` block. Enabling `postgres`
/// alongside any other backend then left a non-tail block of type `TokenStream2`,
/// a confusing type error. Every single-feature configuration resolves exactly as before.
fn db_type() -> TokenStream2 {
    if cfg!(feature = "postgres") {
        quote! { sqlx::Postgres }
    } else if cfg!(feature = "sqlite") {
        quote! { sqlx::Sqlite }
    } else {
        // `mysql` feature, or no backend feature at all.
        quote! { sqlx::MySql }
    }
}
/// Tokens naming the sqlx arguments type for the selected backend.
///
/// Uses the same precedence as `db_type`: `postgres`, then `sqlite`, then MySQL
/// (also the default). That keeps `Query<'q, DB, Args>` internally consistent even
/// when several backend features are enabled.
fn db_arguments() -> TokenStream2 {
    if cfg!(feature = "postgres") {
        quote! { sqlx::postgres::PgArguments }
    } else if cfg!(feature = "sqlite") {
        // The SQLite arguments type is lifetime-parameterised. The generated methods
        // that splice this in declare a `'q` lifetime for it.
        quote! { sqlx::sqlite::SqliteArguments<'q> }
    } else {
        // `mysql` feature, or no backend feature at all.
        quote! { sqlx::mysql::MySqlArguments }
    }
}
/// Identifier quote character wrapped around the table name in generated SQL.
///
/// PostgreSQL and SQLite use `"`, and MySQL (the default) uses a backtick. Selection
/// goes through one `cfg!` expression, so combining backend features cannot
/// produce a non-tail block type error.
fn table_quote() -> &'static str {
    if cfg!(feature = "postgres") || cfg!(feature = "sqlite") {
        "\""
    } else {
        "`"
    }
}
/// Bind placeholder for SQL that is assembled at macro-expansion time (`static-check` mode).
///
/// PostgreSQL uses numbered placeholders (`$1`, `$2`, ...). Every other backend
/// uses a positional `?`, and the index is ignored.
fn static_placeholder(index: usize) -> String {
    if cfg!(feature = "postgres") {
        format!("${}", index)
    } else {
        String::from("?")
    }
}
fn derive_entity_internal(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let update_form_name = format_ident!("{}UpdateForm", name);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let table_name = get_table_name(&input);
let fields = parse_fields(&input);
let primary_key = fields.iter()
.find(|f| f.is_primary_key)
.or_else(|| fields.iter().find(|f| f.ident == "id" || f.ident == "code"))
.expect("Struct must have a primary key field, either explicitly specified or named 'id' or 'code'");
// Check for timestamp fields - either by attribute or by name
let has_created_at = fields.iter().any(|f| f.is_created_at) ||
fields.iter().any(|f| f.ident == "created_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64")));
let has_updated_at = fields.iter().any(|f| f.is_updated_at) ||
fields.iter().any(|f| f.ident == "updated_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64")));
let version_field = fields.iter()
.find(|f| f.is_version_field)
.or_else(|| fields.iter().find(|&f| is_version_field(f)));
// Find soft delete field (by attribute or by name convention)
// Convention: `is_active` (FALSE = deleted), `is_deleted`/`deleted` (TRUE = deleted)
let soft_delete_field = fields.iter()
.find(|f| f.is_soft_delete)
.or_else(|| fields.iter().find(|f| {
(f.ident == "is_active" || f.ident == "is_deleted" || f.ident == "deleted") &&
matches!(&f.ty, Type::Path(p) if p.path.is_ident("bool"))
}));
// Generate all implementations
let insert_impl = generate_insert_impl(&name, &table_name, primary_key, &fields, has_created_at, has_updated_at, &impl_generics, &ty_generics, &where_clause);
let get_impl = generate_get_impl(&name, &table_name, primary_key, version_field, soft_delete_field, &fields, &impl_generics, &ty_generics, &where_clause);
let update_impl = generate_update_impl(&name, &update_form_name, &table_name, &fields, primary_key, version_field, has_updated_at, &impl_generics, &ty_generics, &where_clause);
let diff_impl = generate_diff_impl(&name, &update_form_name, &fields, primary_key, version_field, &impl_generics, &ty_generics, &where_clause);
let delete_impl = generate_delete_impl(&name, &table_name, primary_key, &impl_generics, &ty_generics, &where_clause);
let soft_delete_impl = generate_soft_delete_impl(&name, &table_name, primary_key, soft_delete_field, &impl_generics, &ty_generics, &where_clause);
let pk_type = &primary_key.ty;
let pk_field_name = &primary_key.ident;
quote! {
#insert_impl
#get_impl
#update_impl
#diff_impl
#delete_impl
#soft_delete_impl
impl #impl_generics #name #ty_generics #where_clause {
pub const fn table_name() -> &'static str {
#table_name
}
pub fn entity_key(#pk_field_name: &#pk_type) -> String {
format!("/entities/{}/{}", #table_name, #pk_field_name)
}
pub fn entity_changes_table_name() -> String {
format!("entity_changes_{}", #table_name)
}
}
}.into()
}
/// Table name for the entity.
///
/// Uses the first `#[table_name(...)]` / `#[table_name = ...]` attribute that carries
/// a string value. Without one, falls back to the snake_case form of the struct name.
fn get_table_name(input: &DeriveInput) -> String {
    for attr in &input.attrs {
        if !attr.path().is_ident("table_name") {
            continue;
        }
        if let Some(explicit) = parse_string_attr(attr) {
            return explicit;
        }
    }
    to_snake_case(&input.ident.to_string())
}
fn parse_fields(input: &DeriveInput) -> Vec<EntityField> {
match &input.data {
Data::Struct(data_struct) => {
data_struct.fields.iter().map(|field| {
let ident = field.ident.as_ref().unwrap().clone();
let db_name = field.attrs.iter()
.find_map(|attr| {
if attr.path().is_ident("rename") {
parse_string_attr(attr)
} else {
None
}
})
.unwrap_or_else(|| ident.to_string());
let type_override = field.attrs.iter()
.find_map(|attr| {
if attr.path().is_ident("field_type") {
parse_string_attr(attr)
} else {
None
}
});
let needs_type_annotation = type_override.is_some() || {
matches!(&field.ty, syn::Type::Path(p) if {
let type_str = quote!(#p).to_string();
type_str.contains("Uuid") || type_str.contains("bool")
})
};
let is_primary_key = field.attrs.iter()
.any(|attr| attr.path().is_ident("primary_key"));
let is_version_field = field.attrs.iter()
.any(|attr| attr.path().is_ident("version"));
let is_soft_delete = field.attrs.iter()
.any(|attr| attr.path().is_ident("soft_delete"));
let is_created_at = field.attrs.iter()
.any(|attr| attr.path().is_ident("created_at"));
let is_updated_at = field.attrs.iter()
.any(|attr| attr.path().is_ident("updated_at"));
EntityField {
ident,
db_name,
ty: field.ty.clone(),
needs_type_annotation,
type_override,
is_primary_key,
is_version_field,
is_soft_delete,
is_created_at,
is_updated_at,
}
}).collect()
}
_ => panic!("Entity can only be derived for structs"),
}
}
/// Naming convention for the optimistic-locking column: a field literally named
/// `version` whose type is one of `u64`, `u32`, `i64` or `i32`.
fn is_version_field(f: &EntityField) -> bool {
    if f.ident != "version" {
        return false;
    }
    match &f.ty {
        Type::Path(p) => ["u64", "u32", "i64", "i32"]
            .iter()
            .any(|int_ty| p.path.is_ident(*int_ty)),
        _ => false,
    }
}
// Generate the insert implementation
/// Generate `insert`, `insert_many`, `upsert` and `insert_or_update` for the entity.
///
/// All fields, including the primary key, are written in declaration order. The
/// returned key(s) are cloned from the in-memory entity, not read back from the
/// database. `_has_created_at` / `_has_updated_at` are accepted but unused:
/// inserts never fill timestamps themselves.
fn generate_insert_impl(
    name: &Ident,
    table_name: &str,
    primary_key: &EntityField,
    fields: &[EntityField],
    _has_created_at: bool,
    _has_updated_at: bool,
    impl_generics: &ImplGenerics,
    ty_generics: &TypeGenerics,
    where_clause: &Option<&WhereClause>,
) -> TokenStream2 {
    let db_names: Vec<&str> = fields.iter().map(|f| f.db_name.as_str()).collect();
    let field_idents: Vec<&Ident> = fields.iter().map(|f| &f.ident).collect();
    let pk_field = &primary_key.ident;
    let pk_type = &primary_key.ty;
    let pk_db_name = primary_key.db_name.as_str();
    let field_count = db_names.len();
    // The column list and its non-PK subset are fixed per entity. Render them once
    // here instead of re-building them at runtime on every generated call.
    let column_list = db_names.join(", ");
    let non_pk_db_names: Vec<&str> = db_names
        .iter()
        .copied()
        .filter(|column| *column != pk_db_name)
        .collect();
    let tq = table_quote();
    let db = db_type();

    quote! {
        impl #impl_generics #name #ty_generics #where_clause {
            pub async fn insert<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                let placeholders: String = (1..=#field_count)
                    .map(|i| ::sqlx_record::prelude::placeholder(i))
                    .collect::<Vec<_>>()
                    .join(", ");
                let insert_stmt = format!(
                    "INSERT INTO {}{}{} ({}) VALUES ({})",
                    #tq, #table_name, #tq,
                    #column_list,
                    placeholders
                );
                sqlx::query(&insert_stmt)
                    #(.bind(&self.#field_idents))*
                    .execute(executor)
                    .await?;
                Ok(self.#pk_field.clone())
            }
            /// Insert multiple entities in a single statement
            pub async fn insert_many<'a, E>(executor: E, entities: &[Self]) -> Result<Vec<#pk_type>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
                Self: Clone,
            {
                if entities.is_empty() {
                    return Ok(vec![]);
                }
                let field_count = #field_count;
                // Placeholders are numbered consecutively across rows:
                // row `r` (0-based), column `c` (1-based) -> `r * field_count + c`.
                let placeholders: Vec<String> = (0..entities.len())
                    .map(|row| {
                        let row_placeholders = (1..=field_count)
                            .map(|col| ::sqlx_record::prelude::placeholder(row * field_count + col))
                            .collect::<Vec<_>>()
                            .join(", ");
                        format!("({})", row_placeholders)
                    })
                    .collect();
                let insert_stmt = format!(
                    "INSERT INTO {}{}{} ({}) VALUES {}",
                    #tq, #table_name, #tq,
                    #column_list,
                    placeholders.join(", ")
                );
                let mut query = sqlx::query(&insert_stmt);
                for entity in entities {
                    #(query = query.bind(&entity.#field_idents);)*
                }
                query.execute(executor).await?;
                Ok(entities.iter().map(|e| e.#pk_field.clone()).collect())
            }
            /// Insert or update on primary key conflict (upsert)
            pub async fn upsert<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                let placeholders: String = (1..=#field_count)
                    .map(|i| ::sqlx_record::prelude::placeholder(i))
                    .collect::<Vec<_>>()
                    .join(", ");
                let non_pk_fields: Vec<&str> = vec![#(#non_pk_db_names),*];
                let upsert_stmt = ::sqlx_record::prelude::build_upsert_stmt(
                    #table_name,
                    &[#(#db_names),*],
                    #pk_db_name,
                    &non_pk_fields,
                    &placeholders,
                );
                sqlx::query(&upsert_stmt)
                    #(.bind(&self.#field_idents))*
                    .execute(executor)
                    .await?;
                Ok(self.#pk_field.clone())
            }
            /// Alias for upsert
            pub async fn insert_or_update<'a, E>(&self, executor: E) -> Result<#pk_type, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                self.upsert(executor).await
            }
        }
    }
}
/// Type text for the sqlx override in `col as "ident: Type"`.
///
/// Returns an empty string when the field needs no annotation. An explicit
/// `#[field_type]` override is returned verbatim. Otherwise the field's own type is
/// rendered; if it is `Option<...>`, the inner type is returned with all spaces removed.
fn get_type_string(field: &EntityField) -> String {
    if !field.needs_type_annotation {
        return String::new();
    }
    if let Some(explicit) = &field.type_override {
        return explicit.clone();
    }
    let ty = &field.ty;
    let rendered = quote!(#ty).to_string();
    let compact = rendered.replace(' ', "");
    match compact
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        Some(inner) => inner.to_string(),
        None => rendered,
    }
}
// Generate the get implementations
/// Generate the read-side API of an entity.
///
/// The generated inherent impl contains:
/// - primary-key lookups: `get_by_<pk>`, `get_by_primary_key` and `get_by_<pk plural>`;
/// - `get_version` / `get_versions` when a version field exists;
/// - `find`, `find_ordered`, `find_one`, `find_ordered_with_limit`, `count`, `paginate`
///   and `find_partial`;
/// - metadata accessors `primary_key_field`, `primary_key_db_field`, `primary_key`
///   and `select_fields`.
///
/// With the `static-check` feature, single-row lookups use `sqlx::query_as!`
/// (compile-time checked SQL with the `col as "ident: Type"` overrides from
/// `get_type_string`). Otherwise all SQL is assembled at runtime.
fn generate_get_impl(
    name: &Ident,
    table_name: &str,
    primary_key: &EntityField,
    version_field: Option<&EntityField>,
    _soft_delete_field: Option<&EntityField>, // Reserved for future auto-filtering
    fields: &[EntityField],
    impl_generics: &ImplGenerics,
    ty_generics: &TypeGenerics,
    where_clause: &Option<&WhereClause>,
) -> TokenStream2 {
    // SELECT expressions exposed through `select_fields()`. Columns that need a sqlx
    // type override are rendered as `col as "ident: Type"`.
    let select_field_list: Vec<String> = fields
        .iter()
        .map(|f| {
            if f.needs_type_annotation {
                format!("{} as \"{}: {}\"", f.db_name, f.ident, get_type_string(f))
            } else {
                f.db_name.clone()
            }
        })
        .collect();
    // Bare column names, used for runtime SELECT lists. They also whitelist
    // caller-supplied column names (ORDER BY, partial selects). The annotated
    // expressions above can never equal a plain column name, so they cannot be used
    // for that check.
    let field_list: Vec<&str> = fields.iter().map(|f| f.db_name.as_str()).collect();
    let select_columns = field_list.join(", ");

    let pk_field = &primary_key.ident;
    let pk_type = &primary_key.ty;
    let pk_field_name = pk_field.to_string();
    let pk_db_field_name = &primary_key.db_name;
    // Plural name for the batch lookup. Falls back to `<pk>_list` when pluralizing
    // does not change the name.
    let multi_pk_field_name = Some(pluralize(&pk_field_name))
        .filter(|x| x != &pk_field_name)
        .unwrap_or_else(|| format!("{}_list", pk_field_name));
    let get_by_func = format_ident!("get_by_{}", pk_field_name);
    let multi_get_by_func = format_ident!("get_by_{}", multi_pk_field_name);
    let tq = table_quote();
    let db = db_type();

    // `get_version` / `get_versions` exist only when a version field was detected.
    let get_version_impl = if let Some(vfield) = version_field {
        let version_field_type = &vfield.ty;
        let version_db_name = &vfield.db_name;
        quote! {
            pub async fn get_version<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<#version_field_type>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                let query = format!(
                    r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#,
                    #version_db_name,
                    #tq, Self::table_name(), #tq,
                    #pk_db_field_name,
                    ::sqlx_record::prelude::placeholder(1)
                );
                sqlx::query_scalar(&query)
                    .bind(#pk_field)
                    .fetch_optional(executor)
                    .await
            }
            pub async fn get_versions<'a, E>(executor: E, keys: &Vec<#pk_type>) -> Result<::std::collections::HashMap<#pk_type, #version_field_type>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                if keys.is_empty() {
                    return Ok(::std::collections::HashMap::new());
                }
                use sqlx::Row;
                let placeholders: String = (1..=keys.len())
                    .map(|i| ::sqlx_record::prelude::placeholder(i))
                    .collect::<Vec<_>>()
                    .join(",");
                let query = format!(
                    r#"SELECT DISTINCT {}, {} FROM {}{}{} WHERE {} IN ({})"#,
                    #pk_db_field_name,
                    #version_db_name,
                    #tq, Self::table_name(), #tq,
                    #pk_db_field_name,
                    placeholders
                );
                let mut query_builder = sqlx::query(&query);
                for key in keys {
                    query_builder = query_builder.bind(key);
                }
                let rows = query_builder.fetch_all(executor).await?;
                let mut result = ::std::collections::HashMap::with_capacity(rows.len());
                for row in rows {
                    let key: #pk_type = row.try_get(0)?;
                    let version: #version_field_type = row.try_get(1)?;
                    result.insert(key, version);
                }
                Ok(result)
            }
        }
    } else {
        quote! {}
    };

    let get_by_impl = if cfg!(feature = "static-check") {
        // Static validation: SQL is fixed at expansion time and checked by `query_as!`.
        let select_stmt = format!(
            r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#,
            select_field_list.join(", "),
            tq, table_name, tq, pk_db_field_name,
            static_placeholder(1)
        );
        quote! {
            pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                sqlx::query_as!(
                    Self,
                    #select_stmt,
                    #pk_field
                )
                .fetch_optional(executor)
                .await
            }
        }
    } else {
        // Runtime: dynamic SQL with backend-specific placeholders.
        quote! {
            pub async fn #get_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                let select_stmt = format!(
                    r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#,
                    #select_columns,
                    #tq, #table_name, #tq, #pk_db_field_name,
                    ::sqlx_record::prelude::placeholder(1)
                );
                sqlx::query_as::<_, Self>(&select_stmt)
                    .bind(#pk_field)
                    .fetch_optional(executor)
                    .await
            }
        }
    };

    // `get_by_primary_key` is a key-name-independent alias of `get_by_<pk>`. When the
    // key is literally named `primary_key`, both names coincide, so the alias is
    // skipped to avoid defining the method twice.
    let get_by_primary_key_alias = if get_by_func == "get_by_primary_key" {
        quote! {}
    } else {
        quote! {
            pub async fn get_by_primary_key<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<Option<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                Self::#get_by_func(executor, #pk_field).await
            }
        }
    };

    quote! {
        impl #impl_generics #name #ty_generics #where_clause {
            #get_by_impl
            #get_by_primary_key_alias
            pub const fn primary_key_field() -> &'static str {
                #pk_field_name
            }
            pub const fn primary_key_db_field() -> &'static str {
                #pk_db_field_name
            }
            pub fn primary_key(&self) -> &#pk_type {
                &self.#pk_field
            }
            pub fn select_fields() -> Vec<&'static str> {
                vec![#(#select_field_list),*]
            }
            pub async fn #multi_get_by_func<'a, E>(executor: E, ids: &[#pk_type]) -> Result<Vec<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                if ids.is_empty() {
                    return Ok(vec![]);
                }
                let placeholders: String = (1..=ids.len())
                    .map(|i| ::sqlx_record::prelude::placeholder(i))
                    .collect::<Vec<_>>()
                    .join(",");
                let query = format!(
                    r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} IN ({})"#,
                    #select_columns,
                    #tq, Self::table_name(), #tq,
                    #pk_db_field_name,
                    placeholders
                );
                let mut q = sqlx::query_as::<_, Self>(&query);
                for id in ids {
                    q = q.bind(id);
                }
                q.fetch_all(executor).await
            }
            #get_version_impl
            pub async fn find<'a, E>(
                executor: E,
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                index: Option<&str>,
            ) -> Result<Vec<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                Self::find_ordered_with_limit(executor, filters, index, Vec::new(), None).await
            }
            pub async fn find_ordered<'a, E>(
                executor: E,
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                index: Option<&str>,
                order_by: Vec<(&str, bool)>,
            ) -> Result<Vec<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                Self::find_ordered_with_limit(executor, filters, index, order_by, None).await
            }
            pub async fn find_one<'a, E>(
                executor: E,
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                index: Option<&str>,
            ) -> Result<Option<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                let found = Self::find_ordered_with_limit(executor, filters, index, Vec::new(), Some((0, 1))).await?;
                Ok(found.into_iter().next())
            }
            pub async fn find_ordered_with_limit<'a, E>(
                executor: E,
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                index: Option<&str>,
                order_by: Vec<(&str, bool)>,
                offset_limit: Option<(u32, u32)>,
            ) -> Result<Vec<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                use ::sqlx_record::prelude::{Filter, bind_as_values};
                let (where_conditions, values) = Filter::build_where_clause(&filters);
                let where_clause = if !where_conditions.is_empty() {
                    format!("WHERE {}", where_conditions)
                } else {
                    String::new()
                };
                let index_clause = ::sqlx_record::prelude::build_index_clause(index);
                // Only known column names may reach ORDER BY. Unknown names are dropped
                // silently, because they are spliced into the SQL text.
                let columns: &[&str] = &[#(#field_list),*];
                let order_terms: Vec<String> = order_by
                    .iter()
                    .filter(|(field, _)| columns.contains(field))
                    .map(|(field, asc)| format!("{} {}", field, if *asc { "ASC" } else { "DESC" }))
                    .collect();
                let order_by_clause = if order_terms.is_empty() {
                    String::new()
                } else {
                    format!("ORDER BY {}", order_terms.join(", "))
                };
                let query = format!(
                    r#"SELECT DISTINCT {} FROM {}{}{} {} {} {} {}"#,
                    #select_columns,
                    #tq, #table_name, #tq,
                    index_clause,
                    where_clause,
                    order_by_clause,
                    offset_limit.map(|(offset, limit)| format!("LIMIT {} OFFSET {}", limit, offset)).unwrap_or_default(),
                );
                let db_query = sqlx::query_as(&query);
                match bind_as_values(db_query, &values).fetch_all(executor).await {
                    Ok(results) => Ok(results),
                    Err(err) => {
                        tracing::error!("Error executing\nQuery: {}\nError => {:?}\n", query, err);
                        Err(err)
                    }
                }
            }
            pub async fn count<'a, E>(
                executor: E,
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                index: Option<&str>,
            ) -> Result<u64, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                use ::sqlx_record::prelude::{Filter, bind_scalar_values};
                let (where_conditions, values) = Filter::build_where_clause(&filters);
                let where_clause = if !where_conditions.is_empty() {
                    format!("WHERE {}", where_conditions)
                } else {
                    String::new()
                };
                let index_clause = ::sqlx_record::prelude::build_index_clause(index);
                let count_expr = ::sqlx_record::prelude::build_count_expr(#pk_db_field_name);
                let query = format!(
                    r#"SELECT {} FROM {}{}{} {} {}"#,
                    count_expr,
                    #tq, #table_name, #tq,
                    index_clause,
                    where_clause,
                );
                let db_query = sqlx::query_scalar::<_, i64>(&query);
                match bind_scalar_values(db_query, &values).fetch_optional(executor).await {
                    Ok(count) => Ok(count.unwrap_or(0) as u64),
                    Err(err) => {
                        tracing::error!("Error executing\nQuery: {}\nError => {:?}\n", query, err);
                        Err(err)
                    }
                }
            }
            /// Paginate results with total count
            pub async fn paginate<'a, E>(
                executor: E,
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                index: Option<&str>,
                order_by: Vec<(&str, bool)>,
                page_request: ::sqlx_record::prelude::PageRequest,
            ) -> Result<::sqlx_record::prelude::Page<Self>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db> + Copy,
            {
                // Get total count first
                let total_count = Self::count(executor, filters.clone(), index).await?;
                // Get page items
                let items = Self::find_ordered_with_limit(
                    executor,
                    filters,
                    index,
                    order_by,
                    Some((page_request.offset(), page_request.limit())),
                ).await?;
                Ok(::sqlx_record::prelude::Page::new(
                    items,
                    total_count,
                    page_request.page,
                    page_request.page_size,
                ))
            }
            /// Select specific fields only (returns raw rows)
            /// Use `sqlx::Row` trait to access fields: `row.try_get::<String, _>("name")?`
            pub async fn find_partial<'a, E>(
                executor: E,
                select_fields: &[&str],
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                index: Option<&str>,
            ) -> Result<Vec<<#db as sqlx::Database>::Row>, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                use ::sqlx_record::prelude::{Filter, bind_values};
                // Keep only known column names, in the caller's order.
                let columns: &[&str] = &[#(#field_list),*];
                let selected: Vec<&str> = select_fields
                    .iter()
                    .copied()
                    .filter(|f| columns.contains(f))
                    .collect();
                if selected.is_empty() {
                    return Ok(vec![]);
                }
                let (where_conditions, values) = Filter::build_where_clause(&filters);
                let where_clause = if !where_conditions.is_empty() {
                    format!("WHERE {}", where_conditions)
                } else {
                    String::new()
                };
                let index_clause = ::sqlx_record::prelude::build_index_clause(index);
                let query = format!(
                    "SELECT DISTINCT {} FROM {}{}{} {} {}",
                    selected.join(", "),
                    #tq, #table_name, #tq,
                    index_clause,
                    where_clause,
                );
                let db_query = sqlx::query(&query);
                bind_values(db_query, &values)
                    .fetch_all(executor)
                    .await
            }
        }
    }
}
/// Whether `ty` is spelled as a byte vector: `Vec<u8>` or `std::vec::Vec<u8>`.
///
/// The comparison is textual on the rendered path with spaces stripped, so other
/// spellings (aliases, `::std::...`) are not recognised.
fn is_binary_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => {
            let compact = quote!(#type_path).to_string().replace(' ', "");
            matches!(compact.as_str(), "Vec<u8>" | "std::vec::Vec<u8>")
        }
        _ => false,
    }
}
/// Generate the `<Entity>UpdateForm` struct and its API.
///
/// The form holds one `Option<T>` per updatable field, plus `_exprs` for
/// expression-based updates. Updatable fields exclude the primary key, a field
/// named `created_at` and the version field. This filter must stay identical to the
/// one in `generate_diff_impl`, which builds the same struct literally.
///
/// `update_stmt_with_values` renders the `SET` clause. Expressions win over plain
/// values, the version column is bumped (wrapping at the type's max), and when
/// `has_updated_at` is set the timestamp column is filled with the current time in
/// milliseconds unless the caller set it.
fn generate_update_impl(
    name: &Ident,
    update_form_name: &Ident,
    table_name: &str,
    fields: &[EntityField],
    primary_key: &EntityField,
    version_field: Option<&EntityField>,
    has_updated_at: bool,
    impl_generics: &ImplGenerics,
    ty_generics: &TypeGenerics,
    where_clause: &Option<&WhereClause>,
) -> TokenStream2 {
    let update_fields: Vec<&EntityField> = fields
        .iter()
        .filter(|f| {
            f.ident != primary_key.ident
                && f.ident != "created_at"
                && version_field.map_or(true, |vf| f.ident != vf.ident)
        })
        .collect();
    let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect();
    let field_types: Vec<_> = update_fields.iter().map(|f| &f.ty).collect();
    let db_names: Vec<_> = update_fields.iter().map(|f| &f.db_name).collect();
    let db = db_type();
    let db_args = db_arguments();

    let setter_methods: Vec<_> = update_fields.iter().map(|field| {
        let method_name = format_ident!("set_{}", field.ident);
        let field_type = &field.ty;
        let field_ident = &field.ident;
        quote! {
            pub fn #method_name<T>(&mut self, value: T)
            where
                T: Into<#field_type>,
            {
                self.#field_ident = Some(value.into());
            }
        }
    }).collect();

    let builder_methods: Vec<_> = update_fields.iter().map(|field| {
        let method_name = format_ident!("with_{}", field.ident);
        let field_type = &field.ty;
        let field_ident = &field.ident;
        quote! {
            pub fn #method_name<T>(mut self, value: T) -> Self
            where
                T: Into<#field_type>,
            {
                self.#field_ident = Some(value.into());
                self
            }
        }
    }).collect();

    // Generate eval_* methods for non-binary fields
    let eval_methods: Vec<_> = update_fields.iter()
        .filter(|f| !is_binary_type(&f.ty))
        .map(|field| {
            let method_name = format_ident!("eval_{}", field.ident);
            let db_name = &field.db_name;
            quote! {
                /// Set field using an UpdateExpr for complex operations (arithmetic, CASE/WHEN, etc.)
                /// Takes precedence over with_* if both are set.
                pub fn #method_name(mut self, expr: ::sqlx_record::prelude::UpdateExpr) -> Self {
                    self._exprs.insert(#db_name, expr);
                    self
                }
            }
        })
        .collect();

    // Version increment - use CASE WHEN for cross-database compatibility
    let version_increment = if let Some(vfield) = version_field {
        let version_db_name = &vfield.db_name;
        // Known integer types wrap to 0 at their maximum; other types just add 1.
        let max_val = match &vfield.ty {
            Type::Path(type_path) if type_path.path.is_ident("u32") => "4294967295",
            Type::Path(type_path) if type_path.path.is_ident("u64") => "18446744073709551615",
            Type::Path(type_path) if type_path.path.is_ident("i32") => "2147483647",
            Type::Path(type_path) if type_path.path.is_ident("i64") => "9223372036854775807",
            _ => "",
        };
        if max_val.is_empty() {
            quote! {
                parts.push(format!("{} = {} + 1", #version_db_name, #version_db_name));
            }
        } else {
            quote! {
                parts.push(format!("{} = CASE WHEN {} = {} THEN 0 ELSE {} + 1 END", #version_db_name, #version_db_name, #max_val, #version_db_name));
            }
        }
    } else {
        quote! {}
    };

    // Locate the real timestamp field instead of assuming it is literally named
    // `updated_at`. A `#[updated_at]` marker on a differently-named field, or a
    // `#[rename]`d column, must use that field's identifier and column name.
    let updated_at_field = if has_updated_at {
        update_fields
            .iter()
            .copied()
            .find(|f| f.is_updated_at)
            .or_else(|| update_fields.iter().copied().find(|f| f.ident == "updated_at"))
    } else {
        None
    };
    // Auto-update the timestamp only if it was not set in the form or via an expression.
    let updated_at_increment = match updated_at_field {
        Some(field) => {
            let ts_ident = &field.ident;
            let ts_db_name = &field.db_name;
            quote! {
                if self.#ts_ident.is_none() && !self._exprs.contains_key(#ts_db_name) {
                    parts.push(format!("{} = {}", #ts_db_name, ::sqlx_record::prelude::placeholder(idx)));
                    values.push(::sqlx_record::prelude::Value::Int64(chrono::Utc::now().timestamp_millis()));
                    idx += 1;
                }
            }
        }
        None => quote! {},
    };

    quote! {
        /// Update form with support for simple value updates and complex expressions
        pub struct #update_form_name #ty_generics #where_clause {
            #(pub #field_idents: Option<#field_types>,)*
            /// Expression-based updates (eval_* methods). Takes precedence over with_* for same field.
            pub _exprs: ::std::collections::HashMap<&'static str, ::sqlx_record::prelude::UpdateExpr>,
        }
        impl #impl_generics Default for #update_form_name #ty_generics #where_clause {
            fn default() -> Self {
                Self {
                    #(#field_idents: None,)*
                    _exprs: ::std::collections::HashMap::new(),
                }
            }
        }
        impl #impl_generics #name #ty_generics #where_clause {
            pub fn update_form() -> #update_form_name #ty_generics {
                #update_form_name::default()
            }
        }
        impl #impl_generics #update_form_name #ty_generics #where_clause {
            pub fn new() -> Self {
                Self::default()
            }
            #(#setter_methods)*
            #(#builder_methods)*
            #(#eval_methods)*
            /// Raw SQL expression escape hatch for any field (no bind parameters).
            /// For expressions with parameters, use `raw_with_values()`.
            pub fn raw(mut self, field: &'static str, sql: impl Into<String>) -> Self {
                self._exprs.insert(field, ::sqlx_record::prelude::UpdateExpr::Raw {
                    sql: sql.into(),
                    values: vec![],
                });
                self
            }
            /// Raw SQL expression escape hatch with bind parameters.
            /// Use `?` for placeholders - they will be converted to proper database placeholders.
            pub fn raw_with_values(mut self, field: &'static str, sql: impl Into<String>, values: Vec<::sqlx_record::prelude::Value>) -> Self {
                self._exprs.insert(field, ::sqlx_record::prelude::UpdateExpr::Raw {
                    sql: sql.into(),
                    values,
                });
                self
            }
            /// Generate UPDATE SET clause and collect values for binding.
            /// Expression fields (eval_*) take precedence over simple value fields (with_*).
            // Depending on the entity, `idx`/`values` may never be mutated or re-read;
            // that is expected, so those lints are silenced for the generated body.
            #[allow(unused_mut, unused_assignments)]
            pub fn update_stmt_with_values(&self) -> (String, Vec<::sqlx_record::prelude::Value>) {
                // Explicit element types: with no updatable fields nothing is pushed,
                // and inference would otherwise fail.
                let mut parts: Vec<String> = Vec::new();
                let mut values: Vec<::sqlx_record::prelude::Value> = Vec::new();
                let mut idx = 1usize;
                #(
                    // Check if this field has an expression (takes precedence)
                    if let Some(expr) = self._exprs.get(#db_names) {
                        let (sql, expr_values) = expr.build_sql(#db_names, idx);
                        parts.push(format!("{} = {}", #db_names, sql));
                        idx += expr_values.len();
                        values.extend(expr_values);
                    } else if let Some(ref value) = self.#field_idents {
                        parts.push(format!("{} = {}", #db_names, ::sqlx_record::prelude::placeholder(idx)));
                        values.push(::sqlx_record::prelude::Value::from(value.clone()));
                        idx += 1;
                    }
                )*
                #version_increment
                #updated_at_increment
                (parts.join(", "), values)
            }
            /// Generate UPDATE SET clause (backward compatible, without values).
            /// For new code, prefer `update_stmt_with_values()`.
            pub fn update_stmt(&self) -> String {
                self.update_stmt_with_values().0
            }
            /// Bind all form values to query in correct order.
            /// Handles both simple values and expression values, respecting expression precedence.
            /// Uses Value enum for proper type handling of Option<T> fields.
            pub fn bind_all_values<'q>(&'q self, mut query: sqlx::query::Query<'q, #db, #db_args>)
                -> sqlx::query::Query<'q, #db, #db_args>
            {
                // Bind exactly the values produced for the SET clause, in the same order.
                let (_, values) = self.update_stmt_with_values();
                for value in values {
                    query = ::sqlx_record::prelude::bind_value_owned(query, value);
                }
                query
            }
            /// Legacy binding method - binds values through the Value enum for proper type handling.
            /// For backward compatibility. New code should use bind_all_values().
            pub fn bind_form_values<'q>(&'q self, query: sqlx::query::Query<'q, #db, #db_args>)
                -> sqlx::query::Query<'q, #db, #db_args>
            {
                self.bind_all_values(query)
            }
            /// Check if this form uses any expressions
            pub fn has_expressions(&self) -> bool {
                !self._exprs.is_empty()
            }
            /// Get the number of bind parameters this form will use.
            pub fn param_count(&self) -> usize {
                self.update_stmt_with_values().1.len()
            }
            pub const fn table_name() -> &'static str {
                #table_name
            }
        }
    }
}
/// Generates the diff helpers and the `UPDATE` methods for an entity and its update form.
///
/// The emitted code contains:
/// - on the update form: `model_diff`, `diff_modify` and `db_diff`, which return a JSON object
///   keyed by column name holding the form values that differ from a model / from the stored row;
/// - on the entity: `to_update_form` and `initial_diff`;
/// - the `<Name>BindFormValues` extension trait for `sqlx::query::Query`;
/// - on the entity: `update`, `update_by_<pk>`, `update_by_<pks>` and `update_by_filter`.
///
/// The primary key, any field literally named `created_at`, and the version field (if any) are
/// excluded from every generated field list. `to_update_form` builds the form with a struct
/// literal over exactly the remaining fields (plus `_exprs`).
fn generate_diff_impl(
    name: &Ident,
    update_form_name: &Ident,
    fields: &[EntityField],
    primary_key: &EntityField,
    version_field: Option<&EntityField>,
    impl_generics: &ImplGenerics,
    ty_generics: &TypeGenerics,
    where_clause: &Option<&WhereClause>,
) -> TokenStream2 {
    // Columns an update form may touch. Kept in sync with the form struct generated elsewhere,
    // so `created_at` is matched by name exactly as before.
    let update_fields: Vec<_> = fields
        .iter()
        .filter(|f| {
            f.ident != primary_key.ident
                && f.ident != "created_at"
                && version_field.map_or(true, |vf| f.ident != vf.ident)
        })
        .collect();
    let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect();
    let db_names: Vec<_> = update_fields.iter().map(|f| &f.db_name).collect();
    let field_types: Vec<_> = update_fields.iter().map(|f| &f.ty).collect();

    let bind_form_values_trait = format_ident!("{}BindFormValues", name);
    let bind_form_values_func = format_ident!("{}", to_snake_case(&name.to_string()));

    let pk_field = &primary_key.ident;
    let pk_type = &primary_key.ty;
    let pk_db_name = &primary_key.db_name;
    let multi_pk_field = format_ident!("{}", pluralize(&pk_field.to_string()));
    let update_by_func = format_ident!("update_by_{}", pk_field);
    let multi_update_by_func = format_ident!("update_by_{}", multi_pk_field);

    let db = db_type();
    let db_args = db_arguments();
    let tq = table_quote();

    // Shared prologue of the key-based update methods.
    // The SET clause and the number of bound form values both come from
    // `update_stmt_with_values()`, the same method `bind_form_values` binds from. Placeholders
    // appended after the SET clause are therefore numbered from the real bind count rather than
    // from a character count of the SQL text. An empty form would render `SET  WHERE ...`
    // (invalid SQL), so it short-circuits to a no-op, like `update_by_filter`.
    let prepare_update = quote! {
        let (update_stmt, param_count) = {
            let (stmt, values) = form.update_stmt_with_values();
            (stmt, values.len())
        };
        if update_stmt.is_empty() {
            return Ok(());
        }
    };
    // Shared error logging for the key-based update methods (expects `err` and `query_str`).
    let log_update_error = quote! {
        tracing::error!("Error updating entity: {:?}\nQuery String: {}\n", err, query_str);
    };

    quote! {
        impl #impl_generics #update_form_name #ty_generics #where_clause {
            pub fn model_diff(&self, other: &#name #ty_generics) -> serde_json::Value {
                let mut changes = serde_json::Map::new();
                #(
                    if let Some(ref value) = &self.#field_idents {
                        if &other.#field_idents != value {
                            changes.insert(#db_names.to_string(), serde_json::json!(value));
                        }
                    }
                )*
                serde_json::Value::Object(changes)
            }

            pub fn diff_modify(&mut self, model: &#name #ty_generics) -> serde_json::Value {
                let mut changes = serde_json::Map::new();
                #(
                    if let Some(ref value) = self.#field_idents {
                        if &model.#field_idents == value {
                            self.#field_idents = None;
                        } else {
                            changes.insert(#db_names.to_string(), serde_json::json!(value));
                        }
                    }
                )*
                serde_json::Value::Object(changes)
            }

            pub async fn db_diff<'q, E>(&self, #pk_field: &#pk_type, executor: E) -> Result<serde_json::Value, sqlx::Error>
            where
                E: sqlx::Executor<'q, Database=#db>,
            {
                use sqlx::Row;
                // Only the columns the form actually sets are read back.
                let mut fields_to_fetch: Vec<&str> = Vec::new();
                #(
                    if self.#field_idents.is_some() {
                        fields_to_fetch.push(#db_names);
                    }
                )*
                if fields_to_fetch.is_empty() {
                    return Ok(serde_json::json!({}));
                }
                let query = format!(
                    "SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}",
                    fields_to_fetch.join(", "),
                    #tq, Self::table_name(), #tq, #pk_db_name,
                    ::sqlx_record::prelude::placeholder(1)
                );
                let row = sqlx::query(&query)
                    .bind(#pk_field)
                    .fetch_one(executor)
                    .await?;
                let mut changes = serde_json::Map::new();
                #(
                    // Every `Some` field was added to the SELECT list above.
                    if let Some(value) = &self.#field_idents {
                        let db_value: #field_types = row.try_get(#db_names)?;
                        if *value != db_value {
                            changes.insert(#db_names.to_string(), serde_json::json!(value));
                        }
                    }
                )*
                Ok(serde_json::Value::Object(changes))
            }
        }

        impl #impl_generics #name #ty_generics #where_clause {
            pub fn to_update_form(&self) -> #update_form_name #ty_generics {
                #update_form_name {
                    #(#field_idents: Some(self.#field_idents.clone()),)*
                    _exprs: std::collections::HashMap::new(),
                }
            }

            pub fn initial_diff(&self) -> serde_json::Value {
                let mut map = serde_json::Map::new();
                #(
                    map.insert(#db_names.to_string(), serde_json::to_value(&self.#field_idents).unwrap_or(serde_json::Value::Null));
                )*
                serde_json::Value::Object(map)
            }
        }

        pub trait #bind_form_values_trait<'q, 'f>
        where 'f: 'q {
            fn #bind_form_values_func(self, form: &'f #update_form_name) -> Self;
        }

        impl<'q, 'f> #bind_form_values_trait<'q, 'f> for sqlx::query::Query<'q, #db, #db_args>
        where 'f: 'q {
            fn #bind_form_values_func(self, form: &'f #update_form_name) -> Self {
                form.bind_form_values(self)
            }
        }

        impl #impl_generics #name #ty_generics #where_clause {
            /// Applies `form` to the row identified by this entity's primary key.
            ///
            /// An empty form is a no-op. Database errors are logged via `tracing` and returned.
            pub async fn update<'a, E>(&self, executor: E, form: #update_form_name) -> Result<(), sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                Self::#update_by_func(executor, &self.#pk_field, form).await
            }

            /// Applies `form` to the row with the given primary key.
            ///
            /// An empty form is a no-op. Database errors are logged via `tracing` and returned.
            pub async fn #update_by_func<'a, E>(executor: E, #pk_field: &#pk_type, form: #update_form_name) -> Result<(), sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                #prepare_update
                // The key placeholder follows every value the form binds.
                let query_str = format!(
                    r#"UPDATE {}{}{} SET {} WHERE {} = {}"#,
                    #tq, Self::table_name(), #tq,
                    update_stmt,
                    #pk_db_name,
                    ::sqlx_record::prelude::placeholder(param_count + 1)
                );
                let result = sqlx::query(&query_str)
                    .#bind_form_values_func(&form)
                    .bind(#pk_field)
                    .execute(executor)
                    .await;
                if let Err(err) = result {
                    #log_update_error
                    return Err(err);
                }
                Ok(())
            }

            /// Applies `form` to every row whose primary key is in the given list.
            ///
            /// An empty key list or an empty form is a no-op. Database errors are logged via
            /// `tracing` and returned.
            pub async fn #multi_update_by_func<'a, E>(executor: E, #multi_pk_field: &[#pk_type], form: #update_form_name) -> Result<(), sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                if #multi_pk_field.is_empty() {
                    return Ok(());
                }
                #prepare_update
                // Key placeholders are numbered after every value the form binds.
                let placeholders: String = (1..=#multi_pk_field.len())
                    .map(|i| ::sqlx_record::prelude::placeholder(param_count + i))
                    .collect::<Vec<_>>()
                    .join(",");
                let query_str = format!(
                    r#"UPDATE {}{}{} SET {} WHERE {} IN ({})"#,
                    #tq, Self::table_name(), #tq,
                    update_stmt,
                    #pk_db_name, placeholders,
                );
                let mut q = sqlx::query(&query_str).#bind_form_values_func(&form);
                for #pk_field in #multi_pk_field {
                    q = q.bind(#pk_field);
                }
                if let Err(err) = q.execute(executor).await {
                    #log_update_error
                    return Err(err);
                }
                Ok(())
            }

            /// Update all records matching the filter conditions
            /// Returns the number of affected rows
            pub async fn update_by_filter<'a, E>(
                executor: E,
                filters: Vec<::sqlx_record::prelude::Filter<'a>>,
                form: #update_form_name,
            ) -> Result<u64, sqlx::Error>
            where
                E: sqlx::Executor<'a, Database=#db>,
            {
                use ::sqlx_record::prelude::{Filter, bind_values};

                if filters.is_empty() {
                    // Require at least one filter to prevent accidental table-wide updates
                    return Err(sqlx::Error::Protocol(
                        "update_by_filter requires at least one filter to prevent accidental table-wide updates".to_string()
                    ));
                }

                let (update_stmt, form_values) = form.update_stmt_with_values();
                if update_stmt.is_empty() {
                    return Ok(0);
                }

                let form_param_count = form_values.len();
                let (where_conditions, filter_values) = Filter::build_where_clause_with_offset(&filters, form_param_count + 1);

                let query_str = format!(
                    r#"UPDATE {}{}{} SET {} WHERE {}"#,
                    #tq, Self::table_name(), #tq,
                    update_stmt,
                    where_conditions,
                );

                // Combine form values and filter values
                let mut all_values = form_values;
                all_values.extend(filter_values);

                let query = sqlx::query(&query_str);
                let result = bind_values(query, &all_values)
                    .execute(executor)
                    .await?;

                Ok(result.rows_affected())
            }
        }
    }
}
// Generate delete implementation - always generated for ALL entities
fn generate_delete_impl(
name: &Ident,
table_name: &str,
primary_key: &EntityField,
impl_generics: &ImplGenerics,
ty_generics: &TypeGenerics,
where_clause: &Option<&WhereClause>,
) -> TokenStream2 {
let pk_field = &primary_key.ident;
let pk_type = &primary_key.ty;
let pk_db_name = &primary_key.db_name;
let db = db_type();
let tq = table_quote();
let pk_field_name = primary_key.ident.to_string();
let hard_delete_by_func = format_ident!("hard_delete_by_{}", pk_field_name);
quote! {
impl #impl_generics #name #ty_generics #where_clause {
/// Hard delete - permanently removes the row from database
pub async fn hard_delete<'a, E>(&self, executor: E) -> Result<(), sqlx::Error>
where
E: sqlx::Executor<'a, Database = #db>,
{
Self::#hard_delete_by_func(executor, &self.#pk_field).await
}
/// Hard delete by primary key - permanently removes the row from database
pub async fn #hard_delete_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error>
where
E: sqlx::Executor<'a, Database = #db>,
{
let query = format!(
"DELETE FROM {}{}{} WHERE {} = {}",
#tq, #table_name, #tq,
#pk_db_name, ::sqlx_record::prelude::placeholder(1)
);
sqlx::query(&query).bind(#pk_field).execute(executor).await?;
Ok(())
}
}
}
}
// Generate soft delete implementation
fn generate_soft_delete_impl(
name: &Ident,
table_name: &str,
primary_key: &EntityField,
soft_delete_field: Option<&EntityField>,
impl_generics: &ImplGenerics,
ty_generics: &TypeGenerics,
where_clause: &Option<&WhereClause>,
) -> TokenStream2 {
let Some(sd_field) = soft_delete_field else {
return quote! {};
};
let pk_field = &primary_key.ident;
let pk_type = &primary_key.ty;
let pk_db_name = &primary_key.db_name;
let sd_db_name = &sd_field.db_name;
let db = db_type();
let tq = table_quote();
let pk_field_name = primary_key.ident.to_string();
let soft_delete_by_func = format_ident!("soft_delete_by_{}", pk_field_name);
let restore_by_func = format_ident!("restore_by_{}", pk_field_name);
// Determine semantics based on field name and attribute:
// - #[soft_delete] attribute: field should be FALSE when deleted (user convention)
// - `is_active` by name: FALSE when deleted, TRUE when active
// - `is_deleted`/`deleted` by name: TRUE when deleted, FALSE when active
let sd_field_name = sd_field.ident.to_string();
let is_inverted = sd_field.is_soft_delete || sd_field_name == "is_active";
let (delete_value, restore_value) = if is_inverted {
("FALSE", "TRUE")
} else {
("TRUE", "FALSE")
};
quote! {
impl #impl_generics #name #ty_generics #where_clause {
/// Soft delete - marks record as deleted without removing from database
pub async fn soft_delete<'a, E>(&self, executor: E) -> Result<(), sqlx::Error>
where
E: sqlx::Executor<'a, Database = #db>,
{
Self::#soft_delete_by_func(executor, &self.#pk_field).await
}
/// Soft delete by primary key
pub async fn #soft_delete_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error>
where
E: sqlx::Executor<'a, Database = #db>,
{
let query = format!(
"UPDATE {}{}{} SET {} = {} WHERE {} = {}",
#tq, #table_name, #tq,
#sd_db_name, #delete_value,
#pk_db_name, ::sqlx_record::prelude::placeholder(1)
);
sqlx::query(&query).bind(#pk_field).execute(executor).await?;
Ok(())
}
/// Restore a soft-deleted record
pub async fn restore<'a, E>(&self, executor: E) -> Result<(), sqlx::Error>
where
E: sqlx::Executor<'a, Database = #db>,
{
Self::#restore_by_func(executor, &self.#pk_field).await
}
/// Restore by primary key
pub async fn #restore_by_func<'a, E>(executor: E, #pk_field: &#pk_type) -> Result<(), sqlx::Error>
where
E: sqlx::Executor<'a, Database = #db>,
{
let query = format!(
"UPDATE {}{}{} SET {} = {} WHERE {} = {}",
#tq, #table_name, #tq,
#sd_db_name, #restore_value,
#pk_db_name, ::sqlx_record::prelude::placeholder(1)
);
sqlx::query(&query).bind(#pk_field).execute(executor).await?;
Ok(())
}
/// Get the soft delete field name
pub const fn soft_delete_field() -> &'static str {
#sd_db_name
}
}
}
}