From acf7321c1bae527a816b85cf64a5601e019c82ef Mon Sep 17 00:00:00 2001 From: Michael Netshipise Date: Thu, 19 Mar 2026 12:41:32 +0200 Subject: [PATCH] Sync all local changes Co-Authored-By: Claude Opus 4.6 --- .claude/agents/sqlx-record-expert.md | 435 +++++++++++++++++++++++++ mcp/src/main.rs | 52 ++- sqlx-record-ctl/src/main.rs | 73 +++-- sqlx-record-derive/src/lib.rs | 293 ++++++++++++----- sqlx-record-derive/src/string_utils.rs | 52 +-- src/conn_provider.rs | 6 +- src/filter.rs | 67 ++-- src/helpers.rs | 24 +- src/lib.rs | 21 +- src/models.rs | 8 +- src/pagination.rs | 13 +- src/repositories.rs | 82 +++-- src/value.rs | 146 +++++---- 13 files changed, 1008 insertions(+), 264 deletions(-) create mode 100644 .claude/agents/sqlx-record-expert.md diff --git a/.claude/agents/sqlx-record-expert.md b/.claude/agents/sqlx-record-expert.md new file mode 100644 index 0000000..9b2c490 --- /dev/null +++ b/.claude/agents/sqlx-record-expert.md @@ -0,0 +1,435 @@ +# sqlx-record Expert + +You are an expert at using the sqlx-record Rust library for database operations. sqlx-record provides derive macros for automatic CRUD operations, audit trails, and type-safe query building on top of SQLx, supporting MySQL, PostgreSQL, and SQLite. + +**PROACTIVE USE**: This agent should be consulted BEFORE writing Entity structs, filters, lookup tables, or audit trail code to avoid common mistakes and follow best practices. + +## CRITICAL: Quick Reference - Avoid These Mistakes + +| Mistake | Wrong | Correct | +|---------|-------|---------| +| Missing `FromRow` derive | `#[derive(Entity)]` | `#[derive(Entity, FromRow)]` | +| Wrong filter syntax | `filters![("field", ">", 5)]` | `"field".gt(5)` or `Filter::GreaterThan(...)` | +| Quoted filter values | `filters![("status", "active")]` for non-String | Use `.into()` for Value conversion | +| Forgetting database feature | `sqlx-record = "0.3"` | `sqlx-record = { version = "0.3", features = ["mysql", "derive"] }` | +| Using `delete()` expecting hard delete | `user.delete(&pool)` | `user.hard_delete(&pool)` for permanent removal | +| Wrong UpdateForm pattern | `User::update_form().name("Bob")` | `User::update_form().with_name("Bob")` | +| Lookup with spaces | `lookup_table!(Status, "in progress")` | `lookup_table!(Status, "in-progress")` (use hyphens) | + +## Your Expertise + +1. **Define entities** with proper attributes and field annotations +2. **Write filters** using the Filter enum and macro system +3. **Set up audit trails** with EntityChange tracking +4. **Design lookup tables** for type-safe enumerations +5. **Implement batch operations** with insert_many and upsert +6. **Use transactions** with the transaction! macro +7. **Build update expressions** for arithmetic and conditional updates +8. **Configure soft delete** and timestamp management +9. **Set up pagination** with Page and PageRequest + +## Entity Definition + +### Struct Attributes +```rust +#[derive(Entity, FromRow, Debug, Clone)] +#[table_name = "users"] // Optional: defaults to snake_case plural +struct User { + #[primary_key] // Required: one field + id: Uuid, + + #[rename("user_name")] // Map to different DB column + name: String, + + #[version] // Auto-increment on update, wraps on overflow + version: u32, + + #[field_type("TEXT")] // SQLx type hint for compile-time validation + bio: Option, + + #[soft_delete] // Enables soft_delete/restore methods + is_active: bool, // Convention: is_active auto-detected + + #[created_at] // Auto-set milliseconds on insert + created_at: i64, + + #[updated_at] // Auto-set milliseconds on update + updated_at: i64, +} +``` + +### Primary Key Types +Supports: `Uuid`, `String`, `i32`, `i64`, `u32`, `u64` + +### Generated Method Naming +Methods are named after the primary key field: +- Field `id` -> `get_by_id`, `update_by_id`, `hard_delete_by_id` +- Field `user_id` -> `get_by_user_id`, `update_by_user_id` +- Plural: `get_by_ids`, `update_by_ids` + +## CRUD Operations + +```rust +// Insert +let user = User { id: new_uuid(), name: "Alice".into(), version: 0, /* ... */ }; +user.insert(&pool).await?; + +// Get +let user = User::get_by_id(&pool, &id).await?; // Option +let users = User::get_by_ids(&pool, &ids).await?; // Vec + +// Find with filters +let users = User::find(&pool, filters![("is_active", true)], None).await?; +let user = User::find_one(&pool, filters![("email", email)], None).await?; + +// Find with ordering: (field, is_ascending) +let users = User::find_ordered(&pool, filters![], None, vec![("created_at", false)]).await?; + +// Find with limit +let users = User::find_ordered_with_limit( + &pool, filters![], None, vec![("name", true)], Some((0, 10)) +).await?; + +// Count +let n = User::count(&pool, filters![("is_active", true)], None).await?; + +// Update (only set fields are updated) +User::update_by_id(&pool, &id, User::update_form().with_name("Bob")).await?; + +// Hard delete (permanent) +user.hard_delete(&pool).await?; +User::hard_delete_by_id(&pool, &id).await?; + +// Soft delete (sets is_active = false) +user.soft_delete(&pool).await?; +user.restore(&pool).await?; + +// Batch insert +User::insert_many(&pool, &users).await?; + +// Upsert (insert or update on PK conflict) +user.upsert(&pool).await?; +``` + +## Filter System + +```rust +use sqlx_record::prelude::*; + +// Simple equality (AND) +filters![("active", true), ("role", "admin")] + +// OR conditions +filter_or![("status", "active"), ("status", "pending")] + +// Operator methods via FilterOps trait +"age".gt(18) // GreaterThan +"age".ge(18) // GreaterThanOrEqual +"age".lt(65) // LessThan +"age".le(65) // LessThanOrEqual +"name".eq("Bob") // Equal +"name".ne("Bob") // NotEqual + +// Pattern matching +Filter::Like("name", "%alice%".into()) +Filter::ILike("name", "%alice%".into()) // Case-insensitive + +// Set operations +Filter::In("status", vec!["active".into(), "pending".into()]) +Filter::NotIn("role", vec!["banned".into()]) + +// Null checks +Filter::IsNull("deleted_at") +Filter::IsNotNull("email") + +// Composition +Filter::And(vec![...]) +Filter::Or(vec![...]) + +// MySQL index hints +User::find(&pool, filters, Some("idx_users_email")).await?; +``` + +## UpdateExpr - Advanced Updates + +```rust +use sqlx_record::prelude::UpdateExpr; + +// Arithmetic: column = column OP value +User::update_form().eval_score(UpdateExpr::Add(10.into())) // score + 10 +User::update_form().eval_score(UpdateExpr::Sub(5.into())) // score - 5 +User::update_form().eval_score(UpdateExpr::Mul(2.into())) // score * 2 +User::update_form().eval_score(UpdateExpr::Div(2.into())) // score / 2 + +// Conditional: CASE/WHEN +User::update_form().eval_tier(UpdateExpr::Case { + branches: vec![ + ("score".gt(100), "gold".into()), + ("score".gt(50), "silver".into()), + ], + default: "bronze".into(), +}) + +// Conditional increment: only if condition met +User::update_form().eval_balance(UpdateExpr::AddIf { + condition: "is_premium".eq(true), + value: 100.into(), +}) + +// Utility operations +UpdateExpr::Coalesce(value) // COALESCE(column, ?) +UpdateExpr::Greatest(value) // GREATEST(column, ?) +UpdateExpr::Least(value) // LEAST(column, ?) + +// Raw SQL escape hatch +User::update_form() + .raw("computed", "COALESCE(a, 0) + COALESCE(b, 0)") + .raw_with_values("adjusted", "value * ? + ?", values![1.5, 10]) +``` + +## Lookup Tables + +```rust +// With database entity (creates struct + enum + constants) +lookup_table!(OrderStatus, "pending", "shipped", "delivered"); +// Generated: struct OrderStatus, enum OrderStatusCode, OrderStatus::PENDING, etc. + +// Without database entity (enum + constants only) +lookup_options!(PaymentMethod, "credit-card", "paypal", "bank-transfer"); +// Generated: enum PaymentMethodCode, PaymentMethod::CREDIT_CARD, etc. + +// Usage +let status = OrderStatus::PENDING; // &str constant +let code = OrderStatusCode::try_from("pending")?; // Enum variant +println!("{}", code.as_str()); // Back to string +``` + +## Audit Trail (EntityChange) + +```rust +use sqlx_record::prelude::*; + +// Record a change +let change = EntityChange { + id: new_uuid(), + entity_id: user.id, + action: Action::Update, + changed_at: now_millis(), + actor_id: Some(current_user.id), + session_id: Some(session.id), + change_set_id: Some(batch_id), + new_value: User::model_diff(&form, &user), // JSON diff +}; +change.insert(&pool, "entity_changes_users").await?; + +// Diff methods +User::model_diff(&form, &existing) // Compare form with model +User::db_diff(&form, &id, &pool).await? // Compare form with DB +User::diff_modify(&mut form, &existing) // Modify form to only include changes +user.to_update_form() // Convert entity to UpdateForm +user.initial_diff() // Full entity as JSON +``` + +## Pagination + +```rust +use sqlx_record::prelude::{Page, PageRequest}; + +let page_req = PageRequest::new(1, 20); // page 1, 20 items per page +let page: Page = User::paginate( + &pool, filters![], None, vec![("name", true)], page_req +).await?; + +page.items // Vec +page.total_count // u64 +page.page // u32 (1-indexed) +page.page_size // u32 +page.total_pages() // u32 +page.has_next() // bool +page.has_prev() // bool +page.is_empty() // bool +page.len() // usize +``` + +## Transactions + +```rust +use sqlx_record::transaction; + +// Automatically commits on success, rolls back on error +let order_id = transaction!(&pool, |tx| { + user.insert(&mut *tx).await?; + order.insert(&mut *tx).await?; + Ok::<_, sqlx::Error>(order.id) +}).await?; +``` + +## ConnProvider - Flexible Connection Management + +```rust +use sqlx_record::prelude::ConnProvider; + +// From borrowed connection +let mut conn = pool.acquire().await?; +let mut provider = ConnProvider::from_ref(&mut conn); + +// From pool (lazy acquisition on first use) +let mut provider = ConnProvider::from_pool(pool.clone()); + +// From transaction +let mut tx = pool.begin().await?; +let mut provider = ConnProvider::from_tx(&mut tx); +// ... operations participate in the transaction ... +tx.commit().await?; + +// Get underlying connection +let conn = provider.get_conn().await?; +``` + +## Database Differences + +| Feature | MySQL | PostgreSQL | SQLite | +|---------|-------|------------|--------| +| Placeholder | `?` | `$1, $2` | `?` | +| Table quote | `` ` `` | `"` | `"` | +| UUID type | `BINARY(16)` | `UUID` | `BLOB` | +| JSON type | `JSON` | `JSONB` | `TEXT` | +| ILIKE | `LOWER() LIKE LOWER()` | Native | `LOWER() LIKE LOWER()` | +| Index hints | `USE INDEX()` | N/A | N/A | +| Unsigned ints | Native | Cast to signed | Cast to signed | + +**Unsigned integer conversion (PostgreSQL/SQLite):** +- `u8` -> `i16`, `u16` -> `i32`, `u32` -> `i64`, `u64` -> `i64` + +## Value Types + +The `Value` enum wraps all supported database types: +- Integers: `Int8`, `Uint8`, `Int16`, `Uint16`, `Int32`, `Uint32`, `Int64`, `Uint64` +- `String`, `Bool`, `VecU8`, `Uuid` +- `NaiveDate`, `NaiveDateTime` + +```rust +// Implicit conversion via Into +let v: Value = "hello".into(); // String +let v: Value = 42i64.into(); // Int64 +let v: Value = true.into(); // Bool +let v: Value = uuid.into(); // Uuid + +// values! macro for collections +let vals = values![1, "hello", true]; +``` + +## Common Patterns + +### Repository with Audit Trail +```rust +pub async fn update_user( + pool: &Pool, + user_id: &Uuid, + form: UserUpdateForm, + actor_id: &Uuid, + session_id: &Uuid, +) -> Result<(), Error> { + let diff = User::db_diff(&form, user_id, pool).await?; + User::update_by_id(pool, user_id, form).await?; + + let change = EntityChange { + id: new_uuid(), + entity_id: *user_id, + action: Action::Update, + changed_at: now_millis(), + actor_id: Some(*actor_id), + session_id: Some(*session_id), + change_set_id: None, + new_value: diff, + }; + change.insert(pool, "entity_changes_users").await?; + Ok(()) +} +``` + +### Filtered Search with Pagination +```rust +pub async fn search_users( + pool: &Pool, + query: Option<&str>, + role: Option<&str>, + page: u32, + page_size: u32, +) -> Result, Error> { + let mut filters = vec![Filter::Equal("is_active", true.into())]; + + if let Some(q) = query { + filters.push(Filter::ILike("name", format!("%{q}%").into())); + } + if let Some(r) = role { + filters.push(Filter::Equal("role", r.into())); + } + + User::paginate( + pool, filters, None, + vec![("name", true)], + PageRequest::new(page, page_size), + ).await +} +``` + +### Lookup-Driven Status Workflow +```rust +lookup_table!(OrderStatus, "pending", "processing", "shipped", "delivered", "cancelled"); + +pub async fn transition_order( + pool: &Pool, + order_id: &Uuid, + new_status: OrderStatusCode, +) -> Result<(), Error> { + let order = Order::get_by_id(pool, order_id).await?.unwrap(); + let current = OrderStatusCode::try_from(order.status.as_str())?; + + // Validate transition + let valid = matches!( + (current, new_status), + (OrderStatusCode::Pending, OrderStatusCode::Processing) + | (OrderStatusCode::Processing, OrderStatusCode::Shipped) + | (OrderStatusCode::Shipped, OrderStatusCode::Delivered) + | (_, OrderStatusCode::Cancelled) + ); + + if !valid { + return Err(Error::Protocol("Invalid status transition".into())); + } + + Order::update_by_id(pool, order_id, + Order::update_form().with_status(new_status.as_str().to_string()) + ).await +} +``` + +## Feature Flags Reference + +```toml +[dependencies] +sqlx-record = { version = "0.3", features = ["mysql", "derive"] } +``` + +| Flag | Description | +|------|-------------| +| `mysql` | MySQL/MariaDB/TiDB support | +| `postgres` | PostgreSQL support | +| `sqlite` | SQLite support | +| `derive` | `#[derive(Entity, Update)]` macros | +| `static-validation` | Compile-time SQLx query validation | + +**Must enable at least one database feature.** + +## CLI Tool (sqlx-record-ctl) + +```bash +# Generate audit table for an entity +sqlx-record-ctl generate-audit-table users + +# List auditable entities +sqlx-record-ctl list-entities +``` + +Requires `entity_changes_metadata` table with `table_name` and `is_auditable` columns. diff --git a/mcp/src/main.rs b/mcp/src/main.rs index 3960504..5711c8c 100644 --- a/mcp/src/main.rs +++ b/mcp/src/main.rs @@ -5,7 +5,10 @@ use std::io::{self, BufRead, Write}; #[derive(Parser)] #[command(name = "sqlx-record-mcp")] -#[command(version, about = "MCP server for sqlx-record documentation and code generation")] +#[command( + version, + about = "MCP server for sqlx-record documentation and code generation" +)] struct Args {} // ============================================================================ @@ -1581,9 +1584,16 @@ async fn get_user_history(pool: &Pool, user_id: &Uuid) -> Result String { - let name = params.get("name").and_then(|v| v.as_str()).unwrap_or("Entity"); - let table = params.get("table").and_then(|v| v.as_str()).unwrap_or("entities"); - let fields: Vec<(&str, &str)> = params.get("fields") + let name = params + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("Entity"); + let table = params + .get("table") + .and_then(|v| v.as_str()) + .unwrap_or("entities"); + let fields: Vec<(&str, &str)> = params + .get("fields") .and_then(|v| v.as_array()) .map(|arr| { arr.iter() @@ -1595,7 +1605,10 @@ fn generate_entity_code(params: &Value) -> String { .collect() }) .unwrap_or_default(); - let has_version = params.get("version").and_then(|v| v.as_bool()).unwrap_or(false); + let has_version = params + .get("version") + .and_then(|v| v.as_bool()) + .unwrap_or(false); let mut code = format!( r#"use sqlx_record::prelude::*; @@ -1624,7 +1637,8 @@ pub struct {} {{ } fn generate_filter_code(params: &Value) -> String { - let conditions: Vec = params.get("conditions") + let conditions: Vec = params + .get("conditions") .and_then(|v| v.as_array()) .map(|arr| { arr.iter() @@ -1652,7 +1666,10 @@ fn generate_filter_code(params: &Value) -> String { }) .unwrap_or_default(); - let logic = params.get("logic").and_then(|v| v.as_str()).unwrap_or("and"); + let logic = params + .get("logic") + .and_then(|v| v.as_str()) + .unwrap_or("and"); if conditions.is_empty() { return "filters![]".to_string(); @@ -1665,14 +1682,22 @@ fn generate_filter_code(params: &Value) -> String { } fn generate_lookup_code(params: &Value) -> String { - let name = params.get("name").and_then(|v| v.as_str()).unwrap_or("Status"); - let codes: Vec<&str> = params.get("codes") + let name = params + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("Status"); + let codes: Vec<&str> = params + .get("codes") .and_then(|v| v.as_array()) .map(|arr| arr.iter().filter_map(|c| c.as_str()).collect()) .unwrap_or_default(); - let with_entity = params.get("with_entity").and_then(|v| v.as_bool()).unwrap_or(true); + let with_entity = params + .get("with_entity") + .and_then(|v| v.as_bool()) + .unwrap_or(true); - let codes_str = codes.iter() + let codes_str = codes + .iter() .map(|c| format!("\"{}\"", c)) .collect::>() .join(", "); @@ -1815,7 +1840,10 @@ fn handle_call_tool(params: &Value) -> Value { }) } "explain_feature" => { - let feature = arguments.get("feature").and_then(|v| v.as_str()).unwrap_or("overview"); + let feature = arguments + .get("feature") + .and_then(|v| v.as_str()) + .unwrap_or("overview"); let doc = match feature { "overview" => OVERVIEW, "derive" => DERIVE_ENTITY, diff --git a/sqlx-record-ctl/src/main.rs b/sqlx-record-ctl/src/main.rs index 4952353..49df607 100644 --- a/sqlx-record-ctl/src/main.rs +++ b/sqlx-record-ctl/src/main.rs @@ -63,14 +63,13 @@ async fn main() -> Result<(), sqlx::Error> { }); // Find all tables marked as auditable in the metadata table - let tables: Vec = sqlx::query( - "SELECT table_name FROM entity_changes_metadata WHERE is_auditable = TRUE" - ) - .fetch_all(&pool) - .await? - .iter() - .map(|row| row.get::("table_name")) - .collect(); + let tables: Vec = + sqlx::query("SELECT table_name FROM entity_changes_metadata WHERE is_auditable = TRUE") + .fetch_all(&pool) + .await? + .iter() + .map(|row| row.get::("table_name")) + .collect(); // Iterate over each table and create/delete an entity_changes table for table_name in tables { @@ -81,10 +80,16 @@ async fn main() -> Result<(), sqlx::Error> { println!("delete table: {}", entity_changes_table); #[cfg(feature = "mysql")] - let drop_stmt = format!("DROP TABLE IF EXISTS {}.{}", schema_name, entity_changes_table); + let drop_stmt = format!( + "DROP TABLE IF EXISTS {}.{}", + schema_name, entity_changes_table + ); #[cfg(feature = "postgres")] - let drop_stmt = format!("DROP TABLE IF EXISTS \"{}\".\"{}\"", schema_name, entity_changes_table); + let drop_stmt = format!( + "DROP TABLE IF EXISTS \"{}\".\"{}\"", + schema_name, entity_changes_table + ); #[cfg(feature = "sqlite")] let drop_stmt = format!("DROP TABLE IF EXISTS \"{}\"", entity_changes_table); @@ -108,28 +113,38 @@ async fn main() -> Result<(), sqlx::Error> { new_value JSON );", schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; // Create indexes sqlx::query(&format!( "CREATE INDEX IF NOT EXISTS idx_{}_entity_id ON {}.{} (entity_id);", entity_changes_table, schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( "CREATE INDEX IF NOT EXISTS idx_{}_change_set_id ON {}.{} (change_set_id);", entity_changes_table, schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( "CREATE INDEX IF NOT EXISTS idx_{}_session_id ON {}.{} (session_id);", entity_changes_table, schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( "CREATE INDEX IF NOT EXISTS idx_{}_actor_id ON {}.{} (actor_id);", entity_changes_table, schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( "CREATE INDEX IF NOT EXISTS idx_{}_entity_id_actor_id ON {}.{} (entity_id, actor_id);", @@ -157,7 +172,9 @@ async fn main() -> Result<(), sqlx::Error> { sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_entity_id ON "{}"."{}" (entity_id);"#, entity_changes_table, schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_change_set_id ON "{}"."{}" (change_set_id);"#, @@ -167,12 +184,16 @@ async fn main() -> Result<(), sqlx::Error> { sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_session_id ON "{}"."{}" (session_id);"#, entity_changes_table, schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_actor_id ON "{}"."{}" (actor_id);"#, entity_changes_table, schema_name, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_entity_id_actor_id ON "{}"."{}" (entity_id, actor_id);"#, @@ -200,22 +221,30 @@ async fn main() -> Result<(), sqlx::Error> { sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_entity_id ON "{}" (entity_id);"#, entity_changes_table, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_change_set_id ON "{}" (change_set_id);"#, entity_changes_table, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_session_id ON "{}" (session_id);"#, entity_changes_table, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_actor_id ON "{}" (actor_id);"#, entity_changes_table, entity_changes_table, - )).execute(&pool).await?; + )) + .execute(&pool) + .await?; sqlx::query(&format!( r#"CREATE INDEX IF NOT EXISTS idx_{}_entity_id_actor_id ON "{}" (entity_id, actor_id);"#, diff --git a/sqlx-record-derive/src/lib.rs b/sqlx-record-derive/src/lib.rs index 57a9a5f..30fd075 100644 --- a/sqlx-record-derive/src/lib.rs +++ b/sqlx-record-derive/src/lib.rs @@ -3,12 +3,13 @@ extern crate proc_macro; use proc_macro::TokenStream; use proc_macro2::{Ident, TokenStream as TokenStream2}; -use quote::{quote, format_ident}; -use syn::{parse_macro_input, DeriveInput, Data, LitStr, Type, ImplGenerics, TypeGenerics, WhereClause}; +use quote::{format_ident, quote}; +use syn::{ + parse_macro_input, Data, DeriveInput, ImplGenerics, LitStr, Type, TypeGenerics, WhereClause, +}; use crate::string_utils::{pluralize, to_snake_case}; - struct EntityField { ident: Ident, db_name: String, @@ -33,7 +34,11 @@ fn parse_string_attr(attr: &syn::Attribute) -> Option { } syn::Meta::NameValue(nv) => { // #[attr = "value"] style - if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. }) = &nv.value { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit), + .. + }) = &nv.value + { Some(lit.value()) } else { None @@ -49,7 +54,19 @@ pub fn derive_update(input: TokenStream) -> TokenStream { derive_entity_internal(input) } -#[proc_macro_derive(Entity, attributes(rename, table_name, primary_key, version, field_type, soft_delete, created_at, updated_at))] +#[proc_macro_derive( + Entity, + attributes( + rename, + table_name, + primary_key, + version, + field_type, + soft_delete, + created_at, + updated_at + ) +)] pub fn derive_entity(input: TokenStream) -> TokenStream { derive_entity_internal(input) } @@ -97,21 +114,34 @@ fn db_arguments() -> TokenStream2 { /// Get table quote character fn table_quote() -> &'static str { #[cfg(feature = "postgres")] - { "\"" } + { + "\"" + } #[cfg(feature = "sqlite")] - { return "\""; } + { + return "\""; + } #[cfg(feature = "mysql")] - { "`" } + { + "`" + } #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] - { "`" } + { + "`" + } } /// Get compile-time placeholder for static-check SQL fn static_placeholder(index: usize) -> String { #[cfg(feature = "postgres")] - { format!("${}", index) } + { + format!("${}", index) + } #[cfg(not(feature = "postgres"))] - { let _ = index; "?".to_string() } + { + let _ = index; + "?".to_string() + } } fn derive_entity_internal(input: TokenStream) -> TokenStream { @@ -129,31 +159,91 @@ fn derive_entity_internal(input: TokenStream) -> TokenStream { .expect("Struct must have a primary key field, either explicitly specified or named 'id' or 'code'"); // Check for timestamp fields - either by attribute or by name - let has_created_at = fields.iter().any(|f| f.is_created_at) || - fields.iter().any(|f| f.ident == "created_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64"))); - let has_updated_at = fields.iter().any(|f| f.is_updated_at) || - fields.iter().any(|f| f.ident == "updated_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64"))); + let has_created_at = fields.iter().any(|f| f.is_created_at) + || fields.iter().any(|f| { + f.ident == "created_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64")) + }); + let has_updated_at = fields.iter().any(|f| f.is_updated_at) + || fields.iter().any(|f| { + f.ident == "updated_at" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("i64")) + }); - let version_field = fields.iter() + let version_field = fields + .iter() .find(|f| f.is_version_field) .or_else(|| fields.iter().find(|&f| is_version_field(f))); // Find soft delete field (by attribute or by name convention) // Convention: `is_active` (FALSE = deleted), `is_deleted`/`deleted` (TRUE = deleted) - let soft_delete_field = fields.iter() - .find(|f| f.is_soft_delete) - .or_else(|| fields.iter().find(|f| { - (f.ident == "is_active" || f.ident == "is_deleted" || f.ident == "deleted") && - matches!(&f.ty, Type::Path(p) if p.path.is_ident("bool")) - })); + let soft_delete_field = fields.iter().find(|f| f.is_soft_delete).or_else(|| { + fields.iter().find(|f| { + (f.ident == "is_active" || f.ident == "is_deleted" || f.ident == "deleted") + && matches!(&f.ty, Type::Path(p) if p.path.is_ident("bool")) + }) + }); // Generate all implementations - let insert_impl = generate_insert_impl(&name, &table_name, primary_key, &fields, has_created_at, has_updated_at, &impl_generics, &ty_generics, &where_clause); - let get_impl = generate_get_impl(&name, &table_name, primary_key, version_field, soft_delete_field, &fields, &impl_generics, &ty_generics, &where_clause); - let update_impl = generate_update_impl(&name, &update_form_name, &table_name, &fields, primary_key, version_field, has_updated_at, &impl_generics, &ty_generics, &where_clause); - let diff_impl = generate_diff_impl(&name, &update_form_name, &fields, primary_key, version_field, &impl_generics, &ty_generics, &where_clause); - let delete_impl = generate_delete_impl(&name, &table_name, primary_key, &impl_generics, &ty_generics, &where_clause); - let soft_delete_impl = generate_soft_delete_impl(&name, &table_name, primary_key, soft_delete_field, &impl_generics, &ty_generics, &where_clause); + let insert_impl = generate_insert_impl( + &name, + &table_name, + primary_key, + &fields, + has_created_at, + has_updated_at, + &impl_generics, + &ty_generics, + &where_clause, + ); + let get_impl = generate_get_impl( + &name, + &table_name, + primary_key, + version_field, + soft_delete_field, + &fields, + &impl_generics, + &ty_generics, + &where_clause, + ); + let update_impl = generate_update_impl( + &name, + &update_form_name, + &table_name, + &fields, + primary_key, + version_field, + has_updated_at, + &impl_generics, + &ty_generics, + &where_clause, + ); + let diff_impl = generate_diff_impl( + &name, + &update_form_name, + &fields, + primary_key, + version_field, + &impl_generics, + &ty_generics, + &where_clause, + ); + let delete_impl = generate_delete_impl( + &name, + &table_name, + primary_key, + &impl_generics, + &ty_generics, + &where_clause, + ); + let soft_delete_impl = generate_soft_delete_impl( + &name, + &table_name, + primary_key, + soft_delete_field, + &impl_generics, + &ty_generics, + &where_clause, + ); let pk_type = &primary_key.ty; let pk_field_name = &primary_key.ident; @@ -179,10 +269,13 @@ fn derive_entity_internal(input: TokenStream) -> TokenStream { format!("entity_changes_{}", #table_name) } } - }.into() + } + .into() } fn get_table_name(input: &DeriveInput) -> String { - input.attrs.iter() + input + .attrs + .iter() .find_map(|attr| { if attr.path().is_ident("table_name") { parse_string_attr(attr) @@ -195,10 +288,14 @@ fn get_table_name(input: &DeriveInput) -> String { fn parse_fields(input: &DeriveInput) -> Vec { match &input.data { - Data::Struct(data_struct) => { - data_struct.fields.iter().map(|field| { + Data::Struct(data_struct) => data_struct + .fields + .iter() + .map(|field| { let ident = field.ident.as_ref().unwrap().clone(); - let db_name = field.attrs.iter() + let db_name = field + .attrs + .iter() .find_map(|attr| { if attr.path().is_ident("rename") { parse_string_attr(attr) @@ -208,14 +305,13 @@ fn parse_fields(input: &DeriveInput) -> Vec { }) .unwrap_or_else(|| ident.to_string()); - let type_override = field.attrs.iter() - .find_map(|attr| { - if attr.path().is_ident("field_type") { - parse_string_attr(attr) - } else { - None - } - }); + let type_override = field.attrs.iter().find_map(|attr| { + if attr.path().is_ident("field_type") { + parse_string_attr(attr) + } else { + None + } + }); let needs_type_annotation = type_override.is_some() || { matches!(&field.ty, syn::Type::Path(p) if { @@ -224,15 +320,25 @@ fn parse_fields(input: &DeriveInput) -> Vec { }) }; - let is_primary_key = field.attrs.iter() + let is_primary_key = field + .attrs + .iter() .any(|attr| attr.path().is_ident("primary_key")); - let is_version_field = field.attrs.iter() + let is_version_field = field + .attrs + .iter() .any(|attr| attr.path().is_ident("version")); - let is_soft_delete = field.attrs.iter() + let is_soft_delete = field + .attrs + .iter() .any(|attr| attr.path().is_ident("soft_delete")); - let is_created_at = field.attrs.iter() + let is_created_at = field + .attrs + .iter() .any(|attr| attr.path().is_ident("created_at")); - let is_updated_at = field.attrs.iter() + let is_updated_at = field + .attrs + .iter() .any(|attr| attr.path().is_ident("updated_at")); EntityField { @@ -247,14 +353,15 @@ fn parse_fields(input: &DeriveInput) -> Vec { is_created_at, is_updated_at, } - }).collect() - } + }) + .collect(), _ => panic!("Entity can only be derived for structs"), } } fn is_version_field(f: &EntityField) -> bool { - f.ident == "version" && matches!(&f.ty, Type::Path(p) if p.path.is_ident("u64") || + f.ident == "version" + && matches!(&f.ty, Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("u32") || p.path.is_ident("i64") || p.path.is_ident("i32")) } @@ -276,10 +383,13 @@ fn generate_insert_impl( let db = db_type(); let pk_db_name = &primary_key.db_name; - let bindings: Vec<_> = fields.iter().map(|f| { - let ident = &f.ident; - quote! { &self.#ident } - }).collect(); + let bindings: Vec<_> = fields + .iter() + .map(|f| { + let ident = &f.ident; + quote! { &self.#ident } + }) + .collect(); let pk_field = &primary_key.ident; let pk_type = &primary_key.ty; @@ -410,7 +520,7 @@ fn get_type_string(field: &EntityField) -> String { if clean_type.starts_with("Option<") && clean_type.ends_with(">") { // Extract inner type between < and > - clean_type[7..clean_type.len()-1].to_string() + clean_type[7..clean_type.len() - 1].to_string() } else { type_str } @@ -426,7 +536,7 @@ fn generate_get_impl( table_name: &str, primary_key: &EntityField, version_field: Option<&EntityField>, - _soft_delete_field: Option<&EntityField>, // Reserved for future auto-filtering + _soft_delete_field: Option<&EntityField>, // Reserved for future auto-filtering fields: &[EntityField], impl_generics: &ImplGenerics, ty_generics: &TypeGenerics, @@ -456,9 +566,11 @@ fn generate_get_impl( let db = db_type(); let new_fields = select_fields.clone().collect::>(); - let select_fields_str = new_fields.iter() - .filter_map(|e| e.split(" ") - .next()).collect::>().join(", "); + let select_fields_str = new_fields + .iter() + .filter_map(|e| e.split(" ").next()) + .collect::>() + .join(", "); let select_field_list = select_fields.clone().collect::>(); @@ -528,8 +640,6 @@ fn generate_get_impl( Ok(result) } } - - } else { // If no version field, generate empty implementation quote! {} @@ -545,7 +655,10 @@ fn generate_get_impl( let select_stmt = format!( r#"SELECT DISTINCT {} FROM {}{}{} WHERE {} = {}"#, select_fields.clone().collect::>().join(", "), - tq, table_name, tq, pk_db_field_name, + tq, + table_name, + tq, + pk_db_field_name, static_placeholder(1) ); quote! { @@ -922,9 +1035,16 @@ fn generate_update_impl( ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { - let update_fields: Vec<_> = fields.iter() - .filter(|f| f.ident != primary_key.ident && f.ident != "created_at" && - version_field.as_ref().map(|vf| f.ident != vf.ident).unwrap_or(true)) + let update_fields: Vec<_> = fields + .iter() + .filter(|f| { + f.ident != primary_key.ident + && f.ident != "created_at" + && version_field + .as_ref() + .map(|vf| f.ident != vf.ident) + .unwrap_or(true) + }) .collect(); let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect(); @@ -933,19 +1053,22 @@ fn generate_update_impl( let db = db_type(); let db_args = db_arguments(); - let setter_methods: Vec<_> = update_fields.iter().map(|field| { - let method_name = format_ident!("set_{}", field.ident); - let field_type = &field.ty; - let field_ident = &field.ident; - quote! { - pub fn #method_name(&mut self, value: T) -> () - where - T: Into<#field_type>, - { - self.#field_ident = Some(value.into()); + let setter_methods: Vec<_> = update_fields + .iter() + .map(|field| { + let method_name = format_ident!("set_{}", field.ident); + let field_type = &field.ty; + let field_ident = &field.ident; + quote! { + pub fn #method_name(&mut self, value: T) -> () + where + T: Into<#field_type>, + { + self.#field_ident = Some(value.into()); + } } - } - }).collect(); + }) + .collect(); let builder_methods = update_fields.iter().map(|field| { let method_name = format_ident!("with_{}", field.ident); @@ -963,7 +1086,8 @@ fn generate_update_impl( }); // Generate eval_* methods for non-binary fields - let eval_methods: Vec<_> = update_fields.iter() + let eval_methods: Vec<_> = update_fields + .iter() .filter(|f| !is_binary_type(&f.ty)) .map(|field| { let method_name = format_ident!("eval_{}", field.ident); @@ -1163,9 +1287,16 @@ fn generate_diff_impl( ty_generics: &TypeGenerics, where_clause: &Option<&WhereClause>, ) -> TokenStream2 { - let update_fields: Vec<_> = fields.iter() - .filter(|f| f.ident != primary_key.ident && f.ident != "created_at" && - version_field.as_ref().map(|vf| f.ident != vf.ident).unwrap_or(true)) + let update_fields: Vec<_> = fields + .iter() + .filter(|f| { + f.ident != primary_key.ident + && f.ident != "created_at" + && version_field + .as_ref() + .map(|vf| f.ident != vf.ident) + .unwrap_or(true) + }) .collect(); let field_idents: Vec<_> = update_fields.iter().map(|f| &f.ident).collect(); @@ -1178,7 +1309,7 @@ fn generate_diff_impl( let pk_type = primary_key.ty.clone(); let pk_db_name = primary_key.db_name.clone(); - let multi_pk_field= format_ident!("{}", pluralize(&pk_field.to_string())); + let multi_pk_field = format_ident!("{}", pluralize(&pk_field.to_string())); let update_by_func = format_ident!("update_by_{}", pk_field); let multi_update_by_func = format_ident!("update_by_{}", multi_pk_field); diff --git a/sqlx-record-derive/src/string_utils.rs b/sqlx-record-derive/src/string_utils.rs index 78f4257..a517f09 100644 --- a/sqlx-record-derive/src/string_utils.rs +++ b/sqlx-record-derive/src/string_utils.rs @@ -26,8 +26,8 @@ pub(crate) fn pluralize(word: &str) -> String { } // Handle possessives and existing plurals - if word.ends_with("'s") || word.ends_with("'") || - word.ends_with("s's") || word.ends_with("s'") { + if word.ends_with("'s") || word.ends_with("'") || word.ends_with("s's") || word.ends_with("s'") + { return word.to_string(); } @@ -40,19 +40,15 @@ pub(crate) fn pluralize(word: &str) -> String { // Compound words with hyphens if word.contains('-') { let parts: Vec<&str> = word.split('-').collect(); - return format!("{}-{}", - pluralize(parts[0]), - parts[1..].join("-") - ); + return format!("{}-{}", pluralize(parts[0]), parts[1..].join("-")); } // Invariant words (same singular and plural) match word.to_lowercase().as_str() { - "sheep" | "deer" | "moose" | "swine" | "buffalo" | "fish" | "trout" | - "salmon" | "pike" | "aircraft" | "series" | "species" | "means" | - "crossroads" | "swiss" | "portuguese" | "vietnamese" | "japanese" | - "chinese" | "chassis" | "corps" | "headquarters" | "diabetes" | - "news" | "odds" | "innings" => return word.to_string(), + "sheep" | "deer" | "moose" | "swine" | "buffalo" | "fish" | "trout" | "salmon" | "pike" + | "aircraft" | "series" | "species" | "means" | "crossroads" | "swiss" | "portuguese" + | "vietnamese" | "japanese" | "chinese" | "chassis" | "corps" | "headquarters" + | "diabetes" | "news" | "odds" | "innings" => return word.to_string(), _ => {} } @@ -103,7 +99,8 @@ pub(crate) fn pluralize(word: &str) -> String { "millennium" => "millennia", _ => return apply_general_rules(word), - }.to_string() + } + .to_string() } fn apply_general_rules(word: &str) -> String { @@ -111,9 +108,22 @@ fn apply_general_rules(word: &str) -> String { if word.ends_with('o') { match word.to_lowercase().as_str() { // -o → -oes - w if matches!(w, "hero" | "potato" | "tomato" | "echo" | - "tornado" | "torpedo" | "veto" | "mosquito" | - "volcano" | "buffalo" | "domino" | "embargo") => { + w if matches!( + w, + "hero" + | "potato" + | "tomato" + | "echo" + | "tornado" + | "torpedo" + | "veto" + | "mosquito" + | "volcano" + | "buffalo" + | "domino" + | "embargo" + ) => + { return format!("{}es", word); } // -o → -os @@ -147,12 +157,16 @@ fn apply_general_rules(word: &str) -> String { } // Words ending in sibilants (-s, -ss, -sh, -ch, -x, -z) - if word.ends_with('s') || word.ends_with("ss") || - word.ends_with("sh") || word.ends_with("ch") || - word.ends_with('x') || word.ends_with('z') { + if word.ends_with('s') + || word.ends_with("ss") + || word.ends_with("sh") + || word.ends_with("ch") + || word.ends_with('x') + || word.ends_with('z') + { return format!("{}es", word); } // Default case: add 's' format!("{}s", word) -} \ No newline at end of file +} diff --git a/src/conn_provider.rs b/src/conn_provider.rs index d2271f0..3b23fcd 100644 --- a/src/conn_provider.rs +++ b/src/conn_provider.rs @@ -4,7 +4,7 @@ use sqlx::pool::PoolConnection; use sqlx::{MySql, MySqlConnection, MySqlPool, Transaction}; #[cfg(feature = "postgres")] -use sqlx::{Postgres, PgConnection, PgPool, Transaction}; +use sqlx::{PgConnection, PgPool, Postgres, Transaction}; #[cfg(feature = "sqlite")] use sqlx::{Sqlite, SqliteConnection, SqlitePool, Transaction}; @@ -16,9 +16,7 @@ use sqlx::{Sqlite, SqliteConnection, SqlitePool, Transaction}; #[cfg(feature = "mysql")] pub enum ConnProvider<'a> { /// Stores a reference to an existing connection - Borrowed { - conn: &'a mut PoolConnection, - }, + Borrowed { conn: &'a mut PoolConnection }, /// Stores an owned connection acquired from a pool Owned { pool: MySqlPool, diff --git a/src/filter.rs b/src/filter.rs index 341d9d3..c702b61 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -115,13 +115,21 @@ pub fn placeholder(index: usize) -> String { #[inline] pub fn table_quote() -> &'static str { #[cfg(feature = "mysql")] - { "`" } + { + "`" + } #[cfg(feature = "postgres")] - { "\"" } + { + "\"" + } #[cfg(feature = "sqlite")] - { "\"" } + { + "\"" + } #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] - { "`" } + { + "`" + } } /// Builds an index hint clause (MySQL-specific, empty for other databases) @@ -129,7 +137,9 @@ pub fn table_quote() -> &'static str { pub fn build_index_clause(index: Option<&str>) -> String { #[cfg(feature = "mysql")] { - index.map(|idx| format!("USE INDEX ({})", idx)).unwrap_or_default() + index + .map(|idx| format!("USE INDEX ({})", idx)) + .unwrap_or_default() } #[cfg(not(feature = "mysql"))] { @@ -213,7 +223,7 @@ pub fn build_upsert_stmt( #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] { let _ = pk_field; // Not used in MySQL ON DUPLICATE KEY syntax - // Fallback to MySQL syntax + // Fallback to MySQL syntax let update_clause = non_pk_fields .iter() .map(|f| format!("{} = VALUES({})", f, f)) @@ -252,7 +262,10 @@ impl Filter<'_> { Self::build_where_clause_with_offset(filters, 1) } - pub fn build_where_clause_with_offset(filters: &[Filter], start_index: usize) -> (String, Vec) { + pub fn build_where_clause_with_offset( + filters: &[Filter], + start_index: usize, + ) -> (String, Vec) { let mut values = Vec::new(); let mut current_index = start_index; @@ -323,20 +336,26 @@ impl Filter<'_> { format!("{} NOT LIKE {}", field, ph) } Filter::In(field, value_vec) => { - let placeholders: Vec = value_vec.iter().map(|_| { - let ph = placeholder(current_index); - current_index += 1; - ph - }).collect(); + let placeholders: Vec = value_vec + .iter() + .map(|_| { + let ph = placeholder(current_index); + current_index += 1; + ph + }) + .collect(); values.extend(value_vec.clone()); format!("{} IN ({})", field, placeholders.join(", ")) } Filter::NotIn(field, value_vec) => { - let placeholders: Vec = value_vec.iter().map(|_| { - let ph = placeholder(current_index); - current_index += 1; - ph - }).collect(); + let placeholders: Vec = value_vec + .iter() + .map(|_| { + let ph = placeholder(current_index); + current_index += 1; + ph + }) + .collect(); values.extend(value_vec.clone()); format!("{} NOT IN ({})", field, placeholders.join(", ")) } @@ -347,16 +366,24 @@ impl Filter<'_> { format!("{} IS NOT NULL", field) } Filter::And(nested_filters) => { - let (nested_clause, nested_values) = Self::build_where_clause_with_offset(nested_filters, current_index); + let (nested_clause, nested_values) = + Self::build_where_clause_with_offset(nested_filters, current_index); current_index += nested_values.len(); values.extend(nested_values); format!("({})", nested_clause) } Filter::Or(nested_filters) => { - let (nested_clause, nested_values) = Self::build_where_clause_with_offset(nested_filters, current_index); + let (nested_clause, nested_values) = + Self::build_where_clause_with_offset(nested_filters, current_index); current_index += nested_values.len(); values.extend(nested_values); - format!("({})", nested_clause.split(" AND ").collect::>().join(" OR ")) + format!( + "({})", + nested_clause + .split(" AND ") + .collect::>() + .join(" OR ") + ) } } }) diff --git a/src/helpers.rs b/src/helpers.rs index 8397c97..f4bf0fc 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -1,23 +1,31 @@ #[macro_export] macro_rules! update_entity_func { ($form_type:ident, $func_name:ident) => { - pub async fn $func_name<'a, E>(executor: E, id: &Uuid, form: $form_type) -> Result<(), RepositoryError> + pub async fn $func_name<'a, E>( + executor: E, + id: &Uuid, + form: $form_type, + ) -> Result<(), RepositoryError> where E: sqlx::Executor<'a, Database = $crate::prelude::DB>, { //If the section exists, we update it let result = sqlx::query( - format!(r#" + format!( + r#" UPDATE {} SET {} WHERE id = ? "#, - $form_type::table_name(), - form.update_stmt()).as_str()) - .bind_form_values(form) - .bind(id) - .execute(executor) - .await; + $form_type::table_name(), + form.update_stmt() + ) + .as_str(), + ) + .bind_form_values(form) + .bind(id) + .execute(executor) + .await; if let Err(err) = result { tracing::error!("Error updating entity: {:?}", err); diff --git a/src/lib.rs b/src/lib.rs index bbc78ab..6796245 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,14 +2,14 @@ use chrono::Utc; use rand::random; use uuid::Uuid; -pub mod models; -pub mod repositories; -mod helpers; -mod value; -mod filter; mod conn_provider; +mod filter; +mod helpers; +pub mod models; mod pagination; +pub mod repositories; mod transaction; +mod value; pub use pagination::{Page, PageRequest}; // transaction! macro is exported via #[macro_export] in transaction.rs @@ -174,14 +174,13 @@ macro_rules! lookup_options { } pub mod prelude { - pub use crate::value::*; pub use crate::filter::*; - pub use crate::{filter_or, filter_and, filters, update_entity_func}; - pub use crate::{filter_or as or, filter_and as and}; - pub use crate::values; - pub use crate::{new_uuid, lookup_table, lookup_options, transaction}; pub use crate::pagination::{Page, PageRequest}; - pub use crate::conn_provider::*; + pub use crate::value::*; + pub use crate::values; + pub use crate::{filter_and as and, filter_or as or}; + pub use crate::{filter_and, filter_or, filters, update_entity_func}; + pub use crate::{lookup_options, lookup_table, new_uuid, transaction}; #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub use crate::conn_provider::ConnProvider; diff --git a/src/models.rs b/src/models.rs index 42a0c9d..ea8ad02 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,7 +1,7 @@ -use std::fmt::Display; -use sqlx::FromRow; -use uuid::Uuid; use serde_json::Value; +use sqlx::FromRow; +use std::fmt::Display; +use uuid::Uuid; #[derive(Debug, FromRow)] pub struct EntityChange { @@ -50,4 +50,4 @@ impl Display for Action { }; write!(f, "{}", str) } -} \ No newline at end of file +} diff --git a/src/pagination.rs b/src/pagination.rs index a23544e..313e758 100644 --- a/src/pagination.rs +++ b/src/pagination.rs @@ -13,7 +13,12 @@ pub struct Page { impl Page { pub fn new(items: Vec, total_count: u64, page: u32, page_size: u32) -> Self { - Self { items, total_count, page, page_size } + Self { + items, + total_count, + page, + page_size, + } } /// Total number of pages @@ -93,7 +98,11 @@ impl PageRequest { /// Calculate SQL OFFSET (0-indexed) pub fn offset(&self) -> u32 { - if self.page <= 1 { 0 } else { (self.page - 1) * self.page_size } + if self.page <= 1 { + 0 + } else { + (self.page - 1) * self.page_size + } } /// Calculate SQL LIMIT diff --git a/src/repositories.rs b/src/repositories.rs index 9ffb66b..a2dace4 100644 --- a/src/repositories.rs +++ b/src/repositories.rs @@ -1,6 +1,6 @@ +use crate::models::EntityChange; use sqlx::Error; use uuid::Uuid; -use crate::models::EntityChange; #[cfg(feature = "mysql")] use sqlx::MySqlExecutor as Executor; @@ -29,13 +29,21 @@ fn ph(index: usize) -> String { #[inline] fn table_quote() -> &'static str { #[cfg(feature = "mysql")] - { "`" } + { + "`" + } #[cfg(feature = "postgres")] - { "\"" } + { + "\"" + } #[cfg(feature = "sqlite")] - { "\"" } + { + "\"" + } #[cfg(not(any(feature = "mysql", feature = "postgres", feature = "sqlite")))] - { "" } + { + "" + } } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] @@ -51,7 +59,15 @@ pub async fn create_entity_change<'q>( session_id, change_set_id, new_value) VALUES ({}, {}, {}, {}, {}, {}, {}, {})"#, table_name, - ph(1), ph(2), ph(3), ph(4), ph(5), ph(6), ph(7), ph(8)); + ph(1), + ph(2), + ph(3), + ph(4), + ph(5), + ph(6), + ph(7), + ph(8) + ); sqlx::query(&query) .bind(&change.id) @@ -71,7 +87,9 @@ pub async fn create_entity_change<'q>( #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub async fn get_entity_changes_by_id<'q>( conn: impl Executor<'q>, - table_name: &str, id: &Uuid) -> Result, Error> { + table_name: &str, + id: &Uuid, +) -> Result, Error> { let q = table_quote(); let query = format!( r#"SELECT @@ -84,7 +102,9 @@ pub async fn get_entity_changes_by_id<'q>( change_set_id, new_value FROM {q}{}{q} WHERE id = {}"#, - table_name, ph(1)); + table_name, + ph(1) + ); let changes = sqlx::query_as::<_, EntityChange>(&query) .bind(id) @@ -97,7 +117,9 @@ pub async fn get_entity_changes_by_id<'q>( #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub async fn get_entity_changes_by_entity<'q>( conn: impl Executor<'q>, - table_name: &str, entity_id: &Uuid) -> Result, Error> { + table_name: &str, + entity_id: &Uuid, +) -> Result, Error> { let q = table_quote(); let query = format!( r#"SELECT @@ -110,7 +132,9 @@ pub async fn get_entity_changes_by_entity<'q>( change_set_id, new_value FROM {q}{}{q} WHERE entity_id = {}"#, - table_name, ph(1)); + table_name, + ph(1) + ); let changes = sqlx::query_as::<_, EntityChange>(&query) .bind(entity_id) @@ -123,8 +147,9 @@ pub async fn get_entity_changes_by_entity<'q>( #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub async fn get_entity_changes_session<'q>( conn: impl Executor<'q>, - table_name: &str, session_id: &Uuid, -) -> Result, Error>{ + table_name: &str, + session_id: &Uuid, +) -> Result, Error> { let q = table_quote(); let query = format!( r#"SELECT @@ -137,7 +162,9 @@ pub async fn get_entity_changes_session<'q>( change_set_id, new_value FROM {q}{}{q} WHERE session_id = {}"#, - table_name, ph(1)); + table_name, + ph(1) + ); let changes = sqlx::query_as::<_, EntityChange>(&query) .bind(session_id) @@ -150,7 +177,9 @@ pub async fn get_entity_changes_session<'q>( #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub async fn get_entity_changes_actor<'q>( conn: impl Executor<'q>, - table_name: &str, actor_id: &Uuid) -> Result, Error>{ + table_name: &str, + actor_id: &Uuid, +) -> Result, Error> { let q = table_quote(); let query = format!( r#"SELECT @@ -163,7 +192,9 @@ pub async fn get_entity_changes_actor<'q>( change_set_id, new_value FROM {q}{}{q} WHERE actor_id = {}"#, - table_name, ph(1)); + table_name, + ph(1) + ); let changes = sqlx::query_as::<_, EntityChange>(&query) .bind(actor_id) @@ -175,8 +206,10 @@ pub async fn get_entity_changes_actor<'q>( #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub async fn get_entity_changes_by_change_set<'q>( - conn: impl Executor<'q>, table_name: &str, change_set_id: &Uuid) -> Result, Error> -{ + conn: impl Executor<'q>, + table_name: &str, + change_set_id: &Uuid, +) -> Result, Error> { let q = table_quote(); let query = format!( r#"SELECT @@ -189,7 +222,9 @@ pub async fn get_entity_changes_by_change_set<'q>( change_set_id, new_value FROM {q}{}{q} WHERE change_set_id = {}"#, - table_name, ph(1)); + table_name, + ph(1) + ); let changes = sqlx::query_as::<_, EntityChange>(&query) .bind(change_set_id) @@ -201,7 +236,11 @@ pub async fn get_entity_changes_by_change_set<'q>( #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] pub async fn get_entity_changes_by_entity_and_actor<'q>( - conn: impl Executor<'q>, table_name: &str, entity_id: &Uuid, actor_id: &Uuid) -> Result, Error>{ + conn: impl Executor<'q>, + table_name: &str, + entity_id: &Uuid, + actor_id: &Uuid, +) -> Result, Error> { let q = table_quote(); let query = format!( r#"SELECT @@ -214,7 +253,10 @@ pub async fn get_entity_changes_by_entity_and_actor<'q>( change_set_id, new_value FROM {q}{}{q} WHERE entity_id = {} AND actor_id = {}"#, - table_name, ph(1), ph(2)); + table_name, + ph(1), + ph(2) + ); let changes = sqlx::query_as::<_, EntityChange>(&query) .bind(entity_id) diff --git a/src/value.rs b/src/value.rs index 7c52857..4f52fbe 100644 --- a/src/value.rs +++ b/src/value.rs @@ -1,6 +1,6 @@ +use crate::filter::placeholder; use sqlx::query::{Query, QueryAs, QueryScalar}; use sqlx::types::chrono::{NaiveDate, NaiveDateTime, NaiveTime}; -use crate::filter::placeholder; // Database type alias based on enabled feature #[cfg(feature = "mysql")] @@ -110,10 +110,7 @@ pub enum UpdateExpr { /// Raw SQL expression escape hatch: column = {sql} /// Placeholders in sql should use `?` and will be replaced with proper placeholders - Raw { - sql: String, - values: Vec, - }, + Raw { sql: String, values: Vec }, } impl UpdateExpr { @@ -123,34 +120,35 @@ impl UpdateExpr { use crate::filter::Filter; match self { - UpdateExpr::Set(v) => { - (placeholder(start_idx), vec![v.clone()]) - } - UpdateExpr::Add(v) => { - (format!("{} + {}", column, placeholder(start_idx)), vec![v.clone()]) - } - UpdateExpr::Sub(v) => { - (format!("{} - {}", column, placeholder(start_idx)), vec![v.clone()]) - } - UpdateExpr::Mul(v) => { - (format!("{} * {}", column, placeholder(start_idx)), vec![v.clone()]) - } - UpdateExpr::Div(v) => { - (format!("{} / {}", column, placeholder(start_idx)), vec![v.clone()]) - } - UpdateExpr::Mod(v) => { - (format!("{} % {}", column, placeholder(start_idx)), vec![v.clone()]) - } + UpdateExpr::Set(v) => (placeholder(start_idx), vec![v.clone()]), + UpdateExpr::Add(v) => ( + format!("{} + {}", column, placeholder(start_idx)), + vec![v.clone()], + ), + UpdateExpr::Sub(v) => ( + format!("{} - {}", column, placeholder(start_idx)), + vec![v.clone()], + ), + UpdateExpr::Mul(v) => ( + format!("{} * {}", column, placeholder(start_idx)), + vec![v.clone()], + ), + UpdateExpr::Div(v) => ( + format!("{} / {}", column, placeholder(start_idx)), + vec![v.clone()], + ), + UpdateExpr::Mod(v) => ( + format!("{} % {}", column, placeholder(start_idx)), + vec![v.clone()], + ), UpdateExpr::Case { branches, default } => { let mut sql_parts = vec!["CASE".to_string()]; let mut values = Vec::new(); let mut idx = start_idx; for (condition, value) in branches { - let (cond_sql, cond_values) = Filter::build_where_clause_with_offset( - &[condition.clone()], - idx, - ); + let (cond_sql, cond_values) = + Filter::build_where_clause_with_offset(&[condition.clone()], idx); idx += cond_values.len(); values.extend(cond_values); @@ -165,10 +163,8 @@ impl UpdateExpr { (sql_parts.join(" "), values) } UpdateExpr::AddIf { condition, value } => { - let (cond_sql, cond_values) = Filter::build_where_clause_with_offset( - &[condition.clone()], - start_idx, - ); + let (cond_sql, cond_values) = + Filter::build_where_clause_with_offset(&[condition.clone()], start_idx); let mut values = cond_values; let val_idx = start_idx + values.len(); @@ -184,10 +180,8 @@ impl UpdateExpr { (sql, values) } UpdateExpr::SubIf { condition, value } => { - let (cond_sql, cond_values) = Filter::build_where_clause_with_offset( - &[condition.clone()], - start_idx, - ); + let (cond_sql, cond_values) = + Filter::build_where_clause_with_offset(&[condition.clone()], start_idx); let mut values = cond_values; let val_idx = start_idx + values.len(); @@ -202,15 +196,18 @@ impl UpdateExpr { (sql, values) } - UpdateExpr::Coalesce(v) => { - (format!("COALESCE({}, {})", column, placeholder(start_idx)), vec![v.clone()]) - } - UpdateExpr::Greatest(v) => { - (format!("GREATEST({}, {})", column, placeholder(start_idx)), vec![v.clone()]) - } - UpdateExpr::Least(v) => { - (format!("LEAST({}, {})", column, placeholder(start_idx)), vec![v.clone()]) - } + UpdateExpr::Coalesce(v) => ( + format!("COALESCE({}, {})", column, placeholder(start_idx)), + vec![v.clone()], + ), + UpdateExpr::Greatest(v) => ( + format!("GREATEST({}, {})", column, placeholder(start_idx)), + vec![v.clone()], + ), + UpdateExpr::Least(v) => ( + format!("LEAST({}, {})", column, placeholder(start_idx)), + vec![v.clone()], + ), UpdateExpr::Raw { sql, values } => { // Replace ? placeholders with proper database placeholders let mut result_sql = String::new(); @@ -239,11 +236,24 @@ impl UpdateExpr { UpdateExpr::Mul(_) => 1, UpdateExpr::Div(_) => 1, UpdateExpr::Mod(_) => 1, - UpdateExpr::Case { branches, default: _ } => { - branches.iter().map(|(f, _)| f.param_count() + 1).sum::() + 1 + UpdateExpr::Case { + branches, + default: _, + } => { + branches + .iter() + .map(|(f, _)| f.param_count() + 1) + .sum::() + + 1 } - UpdateExpr::AddIf { condition, value: _ } => condition.param_count() + 1, - UpdateExpr::SubIf { condition, value: _ } => condition.param_count() + 1, + UpdateExpr::AddIf { + condition, + value: _, + } => condition.param_count() + 1, + UpdateExpr::SubIf { + condition, + value: _, + } => condition.param_count() + 1, UpdateExpr::Coalesce(_) => 1, UpdateExpr::Greatest(_) => 1, UpdateExpr::Least(_) => 1, @@ -319,7 +329,10 @@ macro_rules! bind_value { } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] -pub fn bind_values<'q>(query: Query<'q, DB, Arguments_<'q>>, values: &'q [Value]) -> Query<'q, DB, Arguments_<'q>> { +pub fn bind_values<'q>( + query: Query<'q, DB, Arguments_<'q>>, + values: &'q [Value], +) -> Query<'q, DB, Arguments_<'q>> { let mut query = query; for value in values { query = bind_value!(query, value); @@ -329,7 +342,10 @@ pub fn bind_values<'q>(query: Query<'q, DB, Arguments_<'q>>, values: &'q [Value] /// Bind a single owned Value to a query #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] -pub fn bind_value_owned<'q>(query: Query<'q, DB, Arguments_<'q>>, value: Value) -> Query<'q, DB, Arguments_<'q>> { +pub fn bind_value_owned<'q>( + query: Query<'q, DB, Arguments_<'q>>, + value: Value, +) -> Query<'q, DB, Arguments_<'q>> { match value { Value::Null => query.bind(None::), Value::Int8(v) => query.bind(v), @@ -368,14 +384,20 @@ pub fn bind_value_owned<'q>(query: Query<'q, DB, Arguments_<'q>>, value: Value) } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] -pub fn bind_as_values<'q, O>(query: QueryAs<'q, DB, O, Arguments_<'q>>, values: &'q [Value]) -> QueryAs<'q, DB, O, Arguments_<'q>> { - values.into_iter().fold(query, |query, value| { - bind_value!(query, value) - }) +pub fn bind_as_values<'q, O>( + query: QueryAs<'q, DB, O, Arguments_<'q>>, + values: &'q [Value], +) -> QueryAs<'q, DB, O, Arguments_<'q>> { + values + .into_iter() + .fold(query, |query, value| bind_value!(query, value)) } #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))] -pub fn bind_scalar_values<'q, O>(query: QueryScalar<'q, DB, O, Arguments_<'q>>, values: &'q [Value]) -> QueryScalar<'q, DB, O, Arguments_<'q>> { +pub fn bind_scalar_values<'q, O>( + query: QueryScalar<'q, DB, O, Arguments_<'q>>, + values: &'q [Value], +) -> QueryScalar<'q, DB, O, Arguments_<'q>> { let mut query = query; for value in values { query = bind_value!(query, value); @@ -385,8 +407,11 @@ pub fn bind_scalar_values<'q, O>(query: QueryScalar<'q, DB, O, Arguments_<'q>>, #[inline] pub fn query_fields(fields: Vec<&str>) -> String { - fields.iter().filter_map(|e| e.split(" ").next()) - .collect::>().join(", ") + fields + .iter() + .filter_map(|e| e.split(" ").next()) + .collect::>() + .join(", ") } // From implementations for owned values @@ -656,9 +681,9 @@ impl<'q, O> BindValues<'q> for QueryAs<'q, DB, O, Arguments_<'q>> { type Output = QueryAs<'q, DB, O, Arguments_<'q>>; fn bind_values(self, values: &'q [Value]) -> Self::Output { - values.into_iter().fold(self, |query, value| { - bind_value!(query, value) - }) + values + .into_iter() + .fold(self, |query, value| bind_value!(query, value)) } } @@ -675,7 +700,6 @@ impl<'q, O> BindValues<'q> for QueryScalar<'q, DB, O, Arguments_<'q>> { } } - #[macro_export] macro_rules! values { () => {