Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ std = []
recursive-protection = ["std", "recursive"]
# Enable JSON output in the `cli` example:
json_example = ["serde_json", "serde"]
derive-dialect = ["sqlparser_derive"]
visitor = ["sqlparser_derive"]

[dependencies]
Expand All @@ -61,6 +62,10 @@ simple_logger = "5.0"
matches = "0.1"
pretty_assertions = "1"

[[test]]
name = "sqlparser_derive_dialect"
required-features = ["derive-dialect"]

[package.metadata.docs.rs]
# Document these features on docs.rs
features = ["serde", "visitor"]
features = ["serde", "visitor", "derive-dialect"]
2 changes: 1 addition & 1 deletion derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ edition = "2021"
proc-macro = true

[dependencies]
syn = { version = "2.0", default-features = false, features = ["printing", "parsing", "derive", "proc-macro"] }
syn = { version = "2.0", default-features = false, features = ["full", "printing", "parsing", "derive", "proc-macro", "clone-impls"] }
proc-macro2 = "1.0"
quote = "1.0"
305 changes: 305 additions & 0 deletions derive/src/dialect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Implementation of the `derive_dialect!` macro for creating custom SQL dialects.

use proc_macro2::TokenStream;
use quote::{quote, quote_spanned};
use std::collections::HashSet;
use syn::{
braced,
parse::{Parse, ParseStream},
Error, File, FnArg, Ident, Item, LitBool, LitChar, Pat, ReturnType, Signature, Token,
TraitItem, Type,
};

/// Override value types supported by the macro
pub(crate) enum Override {
Bool(LitBool),
Char(LitChar),
None,
}

/// Parsed input for the `derive_dialect!` macro
pub(crate) struct DeriveDialectInput {
pub name: Ident,
pub base: Type,
pub preserve_type_id: bool,
pub overrides: Vec<(Ident, Override)>,
}

/// `Dialect` trait method attrs
struct DialectMethod {
name: Ident,
signature: Signature,
}

impl Parse for DeriveDialectInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let name: Ident = input.parse()?;
input.parse::<Token![,]>()?;
let base: Type = input.parse()?;

let mut preserve_type_id = false;
let mut overrides = Vec::new();

while input.peek(Token![,]) {
input.parse::<Token![,]>()?;
if input.is_empty() {
break;
}
if input.peek(Ident) {
let ident: Ident = input.parse()?;
match ident.to_string().as_str() {
"preserve_type_id" => {
input.parse::<Token![=]>()?;
preserve_type_id = input.parse::<LitBool>()?.value();
}
"overrides" => {
input.parse::<Token![=]>()?;
let content;
braced!(content in input);
while !content.is_empty() {
let key: Ident = content.parse()?;
content.parse::<Token![=]>()?;
let value = if content.peek(LitBool) {
Override::Bool(content.parse()?)
} else if content.peek(LitChar) {
Override::Char(content.parse()?)
} else if content.peek(Ident) {
let ident: Ident = content.parse()?;
if ident == "None" {
Override::None
} else {
return Err(Error::new(
ident.span(),
format!("Expected `true`, `false`, a char, or `None`, found `{ident}`"),
));
}
} else {
return Err(
content.error("Expected `true`, `false`, a char, or `None`")
);
};
overrides.push((key, value));
if content.peek(Token![,]) {
content.parse::<Token![,]>()?;
}
}
}
other => {
return Err(Error::new(ident.span(), format!(
"Unknown argument `{other}`. Expected `preserve_type_id` or `overrides`."
)));
}
}
}
}
Ok(DeriveDialectInput {
name,
base,
preserve_type_id,
overrides,
})
}
}

/// Entry point for the `derive_dialect!` macro
pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
let err = |msg: String| {
Error::new(proc_macro2::Span::call_site(), msg)
.to_compile_error()
.into()
};

let source = match read_dialect_mod_file() {
Ok(s) => s,
Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")),
};
let file: File = match syn::parse_str(&source) {
Ok(f) => f,
Err(e) => return err(format!("Failed to parse source: {e}")),
};
let methods = match extract_dialect_methods(&file) {
Ok(m) => m,
Err(e) => return e.to_compile_error().into(),
};

// Validate overrides
let bool_names: HashSet<_> = methods
.iter()
.filter(|m| is_bool_method(&m.signature))
.map(|m| m.name.to_string())
.collect();
for (key, value) in &input.overrides {
let key_str = key.to_string();
let err = |msg| Error::new(key.span(), msg).to_compile_error().into();
match value {
Override::Bool(_) if !bool_names.contains(&key_str) => {
return err(format!("Unknown boolean method `{key_str}`"));
}
Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
return err(format!(
"Char/None only valid for `identifier_quote_style`, not `{key_str}`"
));
}
_ => {}
}
}
generate_derived_dialect(&input, &methods).into()
}

/// Generate the complete derived `Dialect` implementation
fn generate_derived_dialect(input: &DeriveDialectInput, methods: &[DialectMethod]) -> TokenStream {
let name = &input.name;
let base = &input.base;

// Helper to find an override by method name
let find_override = |method_name: &str| {
input
.overrides
.iter()
.find(|(k, _)| k == method_name)
.map(|(_, v)| v)
};

// Helper to generate delegation to base dialect
let delegate = |method: &DialectMethod| {
let sig = &method.signature;
let method_name = &method.name;
let params = extract_param_names(sig);
quote_spanned! { method_name.span() => #sig { self.dialect.#method_name(#(#params),*) } }
};

// Generate the struct
let struct_def = quote_spanned! { name.span() =>
#[derive(Debug, Default)]
pub struct #name {
dialect: #base,
}
impl #name {
pub fn new() -> Self { Self::default() }
}
};

// Generate TypeId method body
let type_id_body = if input.preserve_type_id {
quote! { Dialect::dialect(&self.dialect) }
} else {
quote! { ::core::any::TypeId::of::<#name>() }
};

// Generate method implementations
let method_impls = methods.iter().map(|method| {
let method_name = &method.name;
match find_override(&method_name.to_string()) {
Some(Override::Bool(value)) => {
quote_spanned! { method_name.span() => fn #method_name(&self) -> bool { #value } }
}
Some(Override::Char(c)) => {
quote_spanned! { method_name.span() =>
fn identifier_quote_style(&self, _: &str) -> Option<char> { Some(#c) }
}
}
Some(Override::None) => {
quote_spanned! { method_name.span() =>
fn identifier_quote_style(&self, _: &str) -> Option<char> { None }
}
}
None => delegate(method),
}
});

// Wrap impl in a const block with scoped imports so types resolve without qualification
quote! {
#struct_def
const _: () = {
use ::core::iter::Peekable;
use ::core::str::Chars;
use sqlparser::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement};
use sqlparser::dialect::{Dialect, Precedence};
use sqlparser::keywords::Keyword;
use sqlparser::parser::{Parser, ParserError};

impl Dialect for #name {
fn dialect(&self) -> ::core::any::TypeId { #type_id_body }
#(#method_impls)*
}
};
}
}

/// Extract parameter names from a method signature (excluding self)
fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
sig.inputs
.iter()
.filter_map(|arg| match arg {
FnArg::Typed(pt) => match pt.pat.as_ref() {
Pat::Ident(pi) => Some(&pi.ident),
_ => None,
},
_ => None,
})
.collect()
}

/// Read the `dialect/mod.rs` file that contains the Dialect trait.
fn read_dialect_mod_file() -> Result<String, String> {
let manifest_dir =
std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?;
let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs");
std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display()))
}

/// Extract all methods from the `Dialect` trait (excluding `dialect` for TypeId)
fn extract_dialect_methods(file: &File) -> Result<Vec<DialectMethod>, Error> {
let dialect_trait = file
.items
.iter()
.find_map(|item| match item {
Item::Trait(t) if t.ident == "Dialect" => Some(t),
_ => None,
})
.ok_or_else(|| Error::new(proc_macro2::Span::call_site(), "Dialect trait not found"))?;

let mut methods: Vec<_> = dialect_trait
.items
.iter()
.filter_map(|item| match item {
TraitItem::Fn(m) if m.sig.ident != "dialect" => Some(DialectMethod {
name: m.sig.ident.clone(),
signature: m.sig.clone(),
}),
_ => None,
})
.collect();
methods.sort_by_key(|m| m.name.to_string());
Ok(methods)
}

/// Check if a method signature is `fn name(&self) -> bool`
fn is_bool_method(sig: &Signature) -> bool {
sig.inputs.len() == 1
&& matches!(
sig.inputs.first(),
Some(FnArg::Receiver(r)) if r.reference.is_some() && r.mutability.is_none()
)
&& matches!(
&sig.output,
ReturnType::Type(_, ty) if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("bool"))
)
}
Loading