From 30a2cbbe5a0b7aded13c81ad2746c1b03a449bd7 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 8 Jan 2026 17:22:37 +0400 Subject: [PATCH 1/2] Streamlined derivation of new `Dialect` objects --- Cargo.toml | 7 +- derive/Cargo.toml | 2 +- derive/src/dialect.rs | 305 ++++++++++++++++++++++++++++++ derive/src/lib.rs | 276 +++------------------------ derive/src/visit.rs | 268 ++++++++++++++++++++++++++ src/dialect/ansi.rs | 2 +- src/dialect/clickhouse.rs | 2 +- src/dialect/hive.rs | 2 +- src/dialect/mod.rs | 108 ++++++++++- src/dialect/mssql.rs | 2 +- src/dialect/mysql.rs | 2 +- src/dialect/oracle.rs | 2 +- src/dialect/postgresql.rs | 2 +- src/dialect/redshift.rs | 2 +- src/dialect/sqlite.rs | 2 +- src/lib.rs | 3 + tests/sqlparser_derive_dialect.rs | 123 ++++++++++++ 17 files changed, 841 insertions(+), 269 deletions(-) create mode 100644 derive/src/dialect.rs create mode 100644 derive/src/visit.rs create mode 100644 tests/sqlparser_derive_dialect.rs diff --git a/Cargo.toml b/Cargo.toml index 177ab3db31..8945adef7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ std = [] recursive-protection = ["std", "recursive"] # Enable JSON output in the `cli` example: json_example = ["serde_json", "serde"] +derive-dialect = ["sqlparser_derive"] visitor = ["sqlparser_derive"] [dependencies] @@ -61,6 +62,10 @@ simple_logger = "5.0" matches = "0.1" pretty_assertions = "1" +[[test]] +name = "sqlparser_derive_dialect" +required-features = ["derive-dialect"] + [package.metadata.docs.rs] # Document these features on docs.rs -features = ["serde", "visitor"] +features = ["serde", "visitor", "derive-dialect"] diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 549477041b..f2f54926b5 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -36,6 +36,6 @@ edition = "2021" proc-macro = true [dependencies] -syn = { version = "2.0", default-features = false, features = ["printing", "parsing", "derive", "proc-macro"] } +syn = { version = "2.0", default-features = false, 
features = ["full", "printing", "parsing", "derive", "proc-macro", "clone-impls"] } proc-macro2 = "1.0" quote = "1.0" diff --git a/derive/src/dialect.rs b/derive/src/dialect.rs new file mode 100644 index 0000000000..9873e4f7b5 --- /dev/null +++ b/derive/src/dialect.rs @@ -0,0 +1,305 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation of the `derive_dialect!` macro for creating custom SQL dialects. 
+ +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use std::collections::HashSet; +use syn::{ + braced, + parse::{Parse, ParseStream}, + Error, File, FnArg, Ident, Item, LitBool, LitChar, Pat, ReturnType, Signature, Token, + TraitItem, Type, +}; + +/// Override value types supported by the macro +pub(crate) enum Override { + Bool(LitBool), + Char(LitChar), + None, +} + +/// Parsed input for the `derive_dialect!` macro +pub(crate) struct DeriveDialectInput { + pub name: Ident, + pub base: Type, + pub preserve_type_id: bool, + pub overrides: Vec<(Ident, Override)>, +} + +/// `Dialect` trait method attrs +struct DialectMethod { + name: Ident, + signature: Signature, +} + +impl Parse for DeriveDialectInput { + fn parse(input: ParseStream) -> syn::Result { + let name: Ident = input.parse()?; + input.parse::()?; + let base: Type = input.parse()?; + + let mut preserve_type_id = false; + let mut overrides = Vec::new(); + + while input.peek(Token![,]) { + input.parse::()?; + if input.is_empty() { + break; + } + if input.peek(Ident) { + let ident: Ident = input.parse()?; + match ident.to_string().as_str() { + "preserve_type_id" => { + input.parse::()?; + preserve_type_id = input.parse::()?.value(); + } + "overrides" => { + input.parse::()?; + let content; + braced!(content in input); + while !content.is_empty() { + let key: Ident = content.parse()?; + content.parse::()?; + let value = if content.peek(LitBool) { + Override::Bool(content.parse()?) + } else if content.peek(LitChar) { + Override::Char(content.parse()?) 
+ } else if content.peek(Ident) { + let ident: Ident = content.parse()?; + if ident == "None" { + Override::None + } else { + return Err(Error::new( + ident.span(), + format!("Expected `true`, `false`, a char, or `None`, found `{ident}`"), + )); + } + } else { + return Err( + content.error("Expected `true`, `false`, a char, or `None`") + ); + }; + overrides.push((key, value)); + if content.peek(Token![,]) { + content.parse::()?; + } + } + } + other => { + return Err(Error::new(ident.span(), format!( + "Unknown argument `{other}`. Expected `preserve_type_id` or `overrides`." + ))); + } + } + } + } + Ok(DeriveDialectInput { + name, + base, + preserve_type_id, + overrides, + }) + } +} + +/// Entry point for the `derive_dialect!` macro +pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream { + let err = |msg: String| { + Error::new(proc_macro2::Span::call_site(), msg) + .to_compile_error() + .into() + }; + + let source = match read_dialect_mod_file() { + Ok(s) => s, + Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")), + }; + let file: File = match syn::parse_str(&source) { + Ok(f) => f, + Err(e) => return err(format!("Failed to parse source: {e}")), + }; + let methods = match extract_dialect_methods(&file) { + Ok(m) => m, + Err(e) => return e.to_compile_error().into(), + }; + + // Validate overrides + let bool_names: HashSet<_> = methods + .iter() + .filter(|m| is_bool_method(&m.signature)) + .map(|m| m.name.to_string()) + .collect(); + for (key, value) in &input.overrides { + let key_str = key.to_string(); + let err = |msg| Error::new(key.span(), msg).to_compile_error().into(); + match value { + Override::Bool(_) if !bool_names.contains(&key_str) => { + return err(format!("Unknown boolean method `{key_str}`")); + } + Override::Char(_) | Override::None if key_str != "identifier_quote_style" => { + return err(format!( + "Char/None only valid for `identifier_quote_style`, not `{key_str}`" + )); + } + _ => {} + } + } + 
generate_derived_dialect(&input, &methods).into() +} + +/// Generate the complete derived `Dialect` implementation +fn generate_derived_dialect(input: &DeriveDialectInput, methods: &[DialectMethod]) -> TokenStream { + let name = &input.name; + let base = &input.base; + + // Helper to find an override by method name + let find_override = |method_name: &str| { + input + .overrides + .iter() + .find(|(k, _)| k == method_name) + .map(|(_, v)| v) + }; + + // Helper to generate delegation to base dialect + let delegate = |method: &DialectMethod| { + let sig = &method.signature; + let method_name = &method.name; + let params = extract_param_names(sig); + quote_spanned! { method_name.span() => #sig { self.dialect.#method_name(#(#params),*) } } + }; + + // Generate the struct + let struct_def = quote_spanned! { name.span() => + #[derive(Debug, Default)] + pub struct #name { + dialect: #base, + } + impl #name { + pub fn new() -> Self { Self::default() } + } + }; + + // Generate TypeId method body + let type_id_body = if input.preserve_type_id { + quote! { Dialect::dialect(&self.dialect) } + } else { + quote! { ::core::any::TypeId::of::<#name>() } + }; + + // Generate method implementations + let method_impls = methods.iter().map(|method| { + let method_name = &method.name; + match find_override(&method_name.to_string()) { + Some(Override::Bool(value)) => { + quote_spanned! { method_name.span() => fn #method_name(&self) -> bool { #value } } + } + Some(Override::Char(c)) => { + quote_spanned! { method_name.span() => + fn identifier_quote_style(&self, _: &str) -> Option { Some(#c) } + } + } + Some(Override::None) => { + quote_spanned! { method_name.span() => + fn identifier_quote_style(&self, _: &str) -> Option { None } + } + } + None => delegate(method), + } + }); + + // Wrap impl in a const block with scoped imports so types resolve without qualification + quote! 
{ + #struct_def + const _: () = { + use ::core::iter::Peekable; + use ::core::str::Chars; + use sqlparser::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement}; + use sqlparser::dialect::{Dialect, Precedence}; + use sqlparser::keywords::Keyword; + use sqlparser::parser::{Parser, ParserError}; + + impl Dialect for #name { + fn dialect(&self) -> ::core::any::TypeId { #type_id_body } + #(#method_impls)* + } + }; + } +} + +/// Extract parameter names from a method signature (excluding self) +fn extract_param_names(sig: &Signature) -> Vec<&Ident> { + sig.inputs + .iter() + .filter_map(|arg| match arg { + FnArg::Typed(pt) => match pt.pat.as_ref() { + Pat::Ident(pi) => Some(&pi.ident), + _ => None, + }, + _ => None, + }) + .collect() +} + +/// Read the `dialect/mod.rs` file that contains the Dialect trait. +fn read_dialect_mod_file() -> Result { + let manifest_dir = + std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?; + let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs"); + std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display())) +} + +/// Extract all methods from the `Dialect` trait (excluding `dialect` for TypeId) +fn extract_dialect_methods(file: &File) -> Result, Error> { + let dialect_trait = file + .items + .iter() + .find_map(|item| match item { + Item::Trait(t) if t.ident == "Dialect" => Some(t), + _ => None, + }) + .ok_or_else(|| Error::new(proc_macro2::Span::call_site(), "Dialect trait not found"))?; + + let mut methods: Vec<_> = dialect_trait + .items + .iter() + .filter_map(|item| match item { + TraitItem::Fn(m) if m.sig.ident != "dialect" => Some(DialectMethod { + name: m.sig.ident.clone(), + signature: m.sig.clone(), + }), + _ => None, + }) + .collect(); + methods.sort_by_key(|m| m.name.to_string()); + Ok(methods) +} + +/// Check if a method signature is `fn name(&self) -> bool` +fn is_bool_method(sig: &Signature) -> bool { + sig.inputs.len() == 1 + 
&& matches!( + sig.inputs.first(), + Some(FnArg::Receiver(r)) if r.reference.is_some() && r.mutability.is_none() + ) + && matches!( + &sig.output, + ReturnType::Type(_, ty) if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("bool")) + ) +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 08c5c5db4b..e3eaeea6d5 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -15,22 +15,25 @@ // specific language governing permissions and limitations // under the License. -use proc_macro2::TokenStream; -use quote::{format_ident, quote, quote_spanned, ToTokens}; -use syn::spanned::Spanned; -use syn::{ - parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, - Ident, Index, LitStr, Meta, Token, Type, TypePath, -}; -use syn::{Path, PathArguments}; +//! Procedural macros for sqlparser. +//! +//! This crate provides: +//! - [`Visit`] and [`VisitMut`] derive macros for AST traversal. +//! - [`derive_dialect!`] macro for creating custom SQL dialects. 
-/// Implementation of `[#derive(Visit)]` +use quote::quote; +use syn::parse_macro_input; + +mod dialect; +mod visit; + +/// Implementation of `#[derive(VisitMut)]` #[proc_macro_derive(VisitMut, attributes(visit))] pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - derive_visit( + let input = parse_macro_input!(input as syn::DeriveInput); + visit::derive_visit( input, - &VisitType { + &visit::VisitType { visit_trait: quote!(VisitMut), visitor_trait: quote!(VisitorMut), modifier: Some(quote!(mut)), @@ -38,12 +41,13 @@ pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStre ) } -/// Implementation of `[#derive(Visit)]` +/// Implementation of `#[derive(Visit)]` #[proc_macro_derive(Visit, attributes(visit))] pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - derive_visit( + let input = parse_macro_input!(input as syn::DeriveInput); + visit::derive_visit( input, - &VisitType { + &visit::VisitType { visit_trait: quote!(Visit), visitor_trait: quote!(Visitor), modifier: None, @@ -51,241 +55,9 @@ pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::Tok ) } -struct VisitType { - visit_trait: TokenStream, - visitor_trait: TokenStream, - modifier: Option, -} - -fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_macro::TokenStream { - // Parse the input tokens into a syntax tree. - let input = parse_macro_input!(input as DeriveInput); - let name = input.ident; - - let VisitType { - visit_trait, - visitor_trait, - modifier, - } = visit_type; - - let attributes = Attributes::parse(&input.attrs); - // Add a bound `T: Visit` to every type parameter T. 
- let generics = add_trait_bounds(input.generics, visit_type); - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let (pre_visit, post_visit) = attributes.visit(quote!(self)); - let children = visit_children(&input.data, visit_type); - - let expanded = quote! { - // The generated impl. - // Note that it uses [`recursive::recursive`] to protect from stack overflow. - // See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info. - impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause { - #[cfg_attr(feature = "recursive-protection", recursive::recursive)] - fn visit( - &#modifier self, - visitor: &mut V - ) -> ::std::ops::ControlFlow { - #pre_visit - #children - #post_visit - ::std::ops::ControlFlow::Continue(()) - } - } - }; - - proc_macro::TokenStream::from(expanded) -} - -/// Parses attributes that can be provided to this macro -/// -/// `#[visit(leaf, with = "visit_expr")]` -#[derive(Default)] -struct Attributes { - /// Content for the `with` attribute - with: Option, -} - -struct WithIdent { - with: Option, -} -impl Parse for WithIdent { - fn parse(input: ParseStream) -> Result { - let mut result = WithIdent { with: None }; - let ident = input.parse::()?; - if ident != "with" { - return Err(syn::Error::new( - ident.span(), - "Expected identifier to be `with`", - )); - } - input.parse::()?; - let s = input.parse::()?; - result.with = Some(format_ident!("{}", s.value(), span = s.span())); - Ok(result) - } -} - -impl Attributes { - fn parse(attrs: &[Attribute]) -> Self { - let mut out = Self::default(); - for attr in attrs { - if let Meta::List(ref metalist) = attr.meta { - if metalist.path.is_ident("visit") { - match syn::parse2::(metalist.tokens.clone()) { - Ok(with_ident) => { - out.with = with_ident.with; - } - Err(e) => { - panic!("{}", e); - } - } - } - } - } - out - } - - /// Returns the pre and post visit token streams - fn visit(&self, s: TokenStream) -> (Option, 
Option) { - let pre_visit = self.with.as_ref().map(|m| { - let m = format_ident!("pre_{}", m); - quote!(visitor.#m(#s)?;) - }); - let post_visit = self.with.as_ref().map(|m| { - let m = format_ident!("post_{}", m); - quote!(visitor.#m(#s)?;) - }); - (pre_visit, post_visit) - } -} - -// Add a bound `T: Visit` to every type parameter T. -fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. }: &VisitType) -> Generics { - for param in &mut generics.params { - if let GenericParam::Type(ref mut type_param) = *param { - type_param - .bounds - .push(parse_quote!(sqlparser::ast::#visit_trait)); - } - } - generics -} - -// Generate the body of the visit implementation for the given type -fn visit_children( - data: &Data, - VisitType { - visit_trait, - modifier, - .. - }: &VisitType, -) -> TokenStream { - match data { - Data::Struct(data) => match &data.fields { - Fields::Named(fields) => { - let recurse = fields.named.iter().map(|f| { - let name = &f.ident; - let is_option = is_option(&f.ty); - let attributes = Attributes::parse(&f.attrs); - if is_option && attributes.with.is_some() { - let (pre_visit, post_visit) = attributes.visit(quote!(value)); - quote_spanned!(f.span() => - if let Some(value) = &#modifier self.#name { - #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit - } - ) - } else { - let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); - quote_spanned!(f.span() => - #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit - ) - } - }); - quote! { - #(#recurse)* - } - } - Fields::Unnamed(fields) => { - let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| { - let index = Index::from(i); - let attributes = Attributes::parse(&f.attrs); - let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index)); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit) - }); - quote! 
{ - #(#recurse)* - } - } - Fields::Unit => { - quote!() - } - }, - Data::Enum(data) => { - let statements = data.variants.iter().map(|v| { - let name = &v.ident; - match &v.fields { - Fields::Named(fields) => { - let names = fields.named.iter().map(|f| &f.ident); - let visit = fields.named.iter().map(|f| { - let name = &f.ident; - let attributes = Attributes::parse(&f.attrs); - let (pre_visit, post_visit) = attributes.visit(name.to_token_stream()); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit) - }); - - quote!( - Self::#name { #(#names),* } => { - #(#visit)* - } - ) - } - Fields::Unnamed(fields) => { - let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span())); - let visit = fields.unnamed.iter().enumerate().map(|(i, f)| { - let name = format_ident!("_{}", i); - let attributes = Attributes::parse(&f.attrs); - let (pre_visit, post_visit) = attributes.visit(name.to_token_stream()); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit) - }); - - quote! { - Self::#name ( #(#names),*) => { - #(#visit)* - } - } - } - Fields::Unit => { - quote! { - Self::#name => {} - } - } - } - }); - - quote! { - match self { - #(#statements),* - } - } - } - Data::Union(_) => unimplemented!(), - } -} - -fn is_option(ty: &Type) -> bool { - if let Type::Path(TypePath { - path: Path { segments, .. }, - .. - }) = ty - { - if let Some(segment) = segments.last() { - if segment.ident == "Option" { - if let PathArguments::AngleBracketed(args) = &segment.arguments { - return args.args.len() == 1; - } - } - } - } - false +/// Procedural macro for deriving new SQL dialects. 
+#[proc_macro] +pub fn derive_dialect(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as dialect::DeriveDialectInput); + dialect::derive_dialect(input) } diff --git a/derive/src/visit.rs b/derive/src/visit.rs new file mode 100644 index 0000000000..baf3eb583b --- /dev/null +++ b/derive/src/visit.rs @@ -0,0 +1,268 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation of the `Visit` and `VisitMut` derive macros. 
+ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, quote_spanned, ToTokens}; +use syn::spanned::Spanned; +use syn::{ + parse::{Parse, ParseStream}, + parse_quote, Attribute, Data, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, + Token, Type, TypePath, +}; +use syn::{Path, PathArguments}; + +pub(crate) struct VisitType { + pub visit_trait: TokenStream, + pub visitor_trait: TokenStream, + pub modifier: Option, +} + +pub(crate) fn derive_visit( + input: syn::DeriveInput, + visit_type: &VisitType, +) -> proc_macro::TokenStream { + let name = input.ident; + + let VisitType { + visit_trait, + visitor_trait, + modifier, + } = visit_type; + + let attributes = Attributes::parse(&input.attrs); + // Add a bound `T: Visit` to every type parameter T. + let generics = add_trait_bounds(input.generics, visit_type); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let (pre_visit, post_visit) = attributes.visit(quote!(self)); + let children = visit_children(&input.data, visit_type); + + let expanded = quote! { + // The generated impl. + // Note that it uses [`recursive::recursive`] to protect from stack overflow. + // See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info. 
+ impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause { + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] + fn visit( + &#modifier self, + visitor: &mut V + ) -> ::std::ops::ControlFlow { + #pre_visit + #children + #post_visit + ::std::ops::ControlFlow::Continue(()) + } + } + }; + + proc_macro::TokenStream::from(expanded) +} + +/// Parses attributes that can be provided to this macro +/// +/// `#[visit(leaf, with = "visit_expr")]` +#[derive(Default)] +struct Attributes { + /// Content for the `with` attribute + with: Option, +} + +struct WithIdent { + with: Option, +} +impl Parse for WithIdent { + fn parse(input: ParseStream) -> Result { + let mut result = WithIdent { with: None }; + let ident = input.parse::()?; + if ident != "with" { + return Err(syn::Error::new( + ident.span(), + "Expected identifier to be `with`", + )); + } + input.parse::()?; + let s = input.parse::()?; + result.with = Some(format_ident!("{}", s.value(), span = s.span())); + Ok(result) + } +} + +impl Attributes { + fn parse(attrs: &[Attribute]) -> Self { + let mut out = Self::default(); + for attr in attrs { + if let Meta::List(ref metalist) = attr.meta { + if metalist.path.is_ident("visit") { + match syn::parse2::(metalist.tokens.clone()) { + Ok(with_ident) => { + out.with = with_ident.with; + } + Err(e) => { + panic!("{}", e); + } + } + } + } + } + out + } + + /// Returns the pre and post visit token streams + fn visit(&self, s: TokenStream) -> (Option, Option) { + let pre_visit = self.with.as_ref().map(|m| { + let m = format_ident!("pre_{}", m); + quote!(visitor.#m(#s)?;) + }); + let post_visit = self.with.as_ref().map(|m| { + let m = format_ident!("post_{}", m); + quote!(visitor.#m(#s)?;) + }); + (pre_visit, post_visit) + } +} + +// Add a bound `T: Visit` to every type parameter T. +fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. 
}: &VisitType) -> Generics { + for param in &mut generics.params { + if let GenericParam::Type(ref mut type_param) = *param { + type_param + .bounds + .push(parse_quote!(sqlparser::ast::#visit_trait)); + } + } + generics +} + +// Generate the body of the visit implementation for the given type +fn visit_children( + data: &Data, + VisitType { + visit_trait, + modifier, + .. + }: &VisitType, +) -> TokenStream { + match data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => { + let recurse = fields.named.iter().map(|f| { + let name = &f.ident; + let is_option = is_option(&f.ty); + let attributes = Attributes::parse(&f.attrs); + if is_option && attributes.with.is_some() { + let (pre_visit, post_visit) = attributes.visit(quote!(value)); + quote_spanned!(f.span() => + if let Some(value) = &#modifier self.#name { + #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit + } + ) + } else { + let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); + quote_spanned!(f.span() => + #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit + ) + } + }); + quote! { + #(#recurse)* + } + } + Fields::Unnamed(fields) => { + let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| { + let index = Index::from(i); + let attributes = Attributes::parse(&f.attrs); + let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index)); + quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit) + }); + quote! 
{ + #(#recurse)* + } + } + Fields::Unit => { + quote!() + } + }, + Data::Enum(data) => { + let statements = data.variants.iter().map(|v| { + let name = &v.ident; + match &v.fields { + Fields::Named(fields) => { + let names = fields.named.iter().map(|f| &f.ident); + let visit = fields.named.iter().map(|f| { + let name = &f.ident; + let attributes = Attributes::parse(&f.attrs); + let (pre_visit, post_visit) = attributes.visit(name.to_token_stream()); + quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit) + }); + + quote!( + Self::#name { #(#names),* } => { + #(#visit)* + } + ) + } + Fields::Unnamed(fields) => { + let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span())); + let visit = fields.unnamed.iter().enumerate().map(|(i, f)| { + let name = format_ident!("_{}", i); + let attributes = Attributes::parse(&f.attrs); + let (pre_visit, post_visit) = attributes.visit(name.to_token_stream()); + quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit) + }); + + quote! { + Self::#name ( #(#names),*) => { + #(#visit)* + } + } + } + Fields::Unit => { + quote! { + Self::#name => {} + } + } + } + }); + + quote! { + match self { + #(#statements),* + } + } + } + Data::Union(_) => unimplemented!(), + } +} + +fn is_option(ty: &Type) -> bool { + if let Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) = ty + { + if let Some(segment) = segments.last() { + if segment.ident == "Option" { + if let PathArguments::AngleBracketed(args) = &segment.arguments { + return args.args.len() == 1; + } + } + } + } + false +} diff --git a/src/dialect/ansi.rs b/src/dialect/ansi.rs index ec3c095be5..5a54390cfd 100644 --- a/src/dialect/ansi.rs +++ b/src/dialect/ansi.rs @@ -18,7 +18,7 @@ use crate::dialect::Dialect; /// A [`Dialect`] for [ANSI SQL](https://en.wikipedia.org/wiki/SQL:2011). 
-#[derive(Debug)] +#[derive(Debug, Default)] pub struct AnsiDialect {} impl Dialect for AnsiDialect { diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs index 39e8a0b304..2bd7201bf9 100644 --- a/src/dialect/clickhouse.rs +++ b/src/dialect/clickhouse.rs @@ -18,7 +18,7 @@ use crate::dialect::Dialect; /// A [`Dialect`] for [ClickHouse](https://clickhouse.com/). -#[derive(Debug)] +#[derive(Debug, Default)] pub struct ClickHouseDialect {} impl Dialect for ClickHouseDialect { diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index 3e15d395b1..32a982e907 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -18,7 +18,7 @@ use crate::dialect::Dialect; /// A [`Dialect`] for [Hive](https://hive.apache.org/). -#[derive(Debug)] +#[derive(Debug, Default)] pub struct HiveDialect {} impl Dialect for HiveDialect { diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 284fc41726..56acf4ef3f 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -51,6 +51,82 @@ pub use self::postgresql::PostgreSqlDialect; pub use self::redshift::RedshiftSqlDialect; pub use self::snowflake::SnowflakeDialect; pub use self::sqlite::SQLiteDialect; + +/// Macro for streamlining the creation of derived `Dialect` objects. +/// The generated struct includes `new()` and `default()` constructors. +/// Requires the `derive-dialect` feature. +/// +/// # Syntax +/// +/// ```text +/// derive_dialect!(NewDialect, BaseDialect); +/// derive_dialect!(NewDialect, BaseDialect, overrides = { method = value, ... }); +/// derive_dialect!(NewDialect, BaseDialect, preserve_type_id = true); +/// derive_dialect!(NewDialect, BaseDialect, preserve_type_id = true, overrides = { ... }); +/// ``` +/// +/// # Example +/// +/// ``` +/// use sqlparser::derive_dialect; +/// use sqlparser::dialect::{Dialect, GenericDialect}; +/// +/// // Override boolean methods (supports_*, allow_*, etc.) 
+/// derive_dialect!(CustomDialect, GenericDialect, overrides = { +/// supports_order_by_all = true, +/// supports_nested_comments = true, +/// }); +/// +/// let dialect = CustomDialect::new(); +/// assert!(dialect.supports_order_by_all()); +/// assert!(dialect.supports_nested_comments()); +/// ``` +/// +/// # Overriding `identifier_quote_style` +/// +/// Use a char literal or `None`: +/// ``` +/// use sqlparser::derive_dialect; +/// use sqlparser::dialect::{Dialect, PostgreSqlDialect}; +/// +/// derive_dialect!(BacktickPostgreSqlDialect, PostgreSqlDialect, +/// preserve_type_id = true, +/// overrides = { identifier_quote_style = '`' } +/// ); +/// let d: &dyn Dialect = &BacktickPostgreSqlDialect::new(); +/// assert_eq!(d.identifier_quote_style("foo"), Some('`')); +/// +/// derive_dialect!(QuotelessPostgreSqlDialect, PostgreSqlDialect, +/// preserve_type_id = true, +/// overrides = { identifier_quote_style = None } +/// ); +/// let d: &dyn Dialect = &QuotelessPostgreSqlDialect::new(); +/// assert_eq!(d.identifier_quote_style("foo"), None); +/// ``` +/// +/// # Type Identity +/// +/// By default, derived dialects have their own `TypeId`. 
Set `preserve_type_id = true` to +/// retain the base dialect's identity with respect to the parser's `dialect.is::<T>()` checks: +/// ``` +/// use sqlparser::derive_dialect; +/// use sqlparser::dialect::{Dialect, GenericDialect}; +/// +/// derive_dialect!(EnhancedGenericDialect, GenericDialect, +/// preserve_type_id = true, +/// overrides = { +/// supports_order_by_all = true, +/// supports_nested_comments = true, +/// } +/// ); +/// let d: &dyn Dialect = &EnhancedGenericDialect::new(); +/// assert!(d.is::<GenericDialect>()); // still recognized as a GenericDialect +/// assert!(d.supports_nested_comments()); +/// assert!(d.supports_order_by_all()); +/// ``` +#[cfg(feature = "derive-dialect")] +pub use sqlparser_derive::derive_dialect; + use crate::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement}; pub use crate::keywords; use crate::keywords::Keyword; @@ -62,14 +138,14 @@ use alloc::boxed::Box; /// Convenience check if a [`Parser`] uses a certain dialect. /// -/// Note: when possible please the new style, adding a method to the [`Dialect`] -/// trait rather than using this macro. +/// Note: when possible, please use the new style, adding a method to +/// the [`Dialect`] trait rather than using this macro. /// /// The benefits of adding a method on `Dialect` over this macro are: /// 1. user defined [`Dialect`]s can customize the parsing behavior /// 2. The differences between dialects can be clearly documented in the trait /// -/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates +/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates /// to `true` if `parser.dialect` is one of the [`Dialect`]s specified. macro_rules! dialect_of { ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => { @@ -123,9 +199,8 @@ macro_rules! dialect_is { pub trait Dialect: Debug + Any { /// Determine the [`TypeId`] of this dialect. /// - /// By default, return the same [`TypeId`] as [`Any::type_id`].
Can be overridden - /// by dialects that behave like other dialects - /// (for example when wrapping a dialect). + /// By default, return the same [`TypeId`] as [`Any::type_id`]. Can be overridden by + /// dialects that behave like other dialects (for example, when wrapping a dialect). fn dialect(&self) -> TypeId { self.type_id() } @@ -1470,6 +1545,27 @@ mod tests { dialect_from_str(v).unwrap() } + #[test] + #[cfg(feature = "derive-dialect")] + fn test_dialect_override() { + derive_dialect!(EnhancedGenericDialect, GenericDialect, + preserve_type_id = true, + overrides = { + supports_order_by_all = true, + supports_nested_comments = true, + supports_triple_quoted_string = true, + }, + ); + let dialect = EnhancedGenericDialect::new(); + + assert!(dialect.supports_order_by_all()); + assert!(dialect.supports_nested_comments()); + assert!(dialect.supports_triple_quoted_string()); + + let d: &dyn Dialect = &dialect; + assert!(d.is::()); + } + #[test] fn identifier_quote_style() { let tests: Vec<(&dyn Dialect, &str, Option)> = vec![ diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 9f8e726562..24f7c7c4f3 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -28,7 +28,7 @@ use crate::tokenizer::Token; use alloc::{vec, vec::Vec}; /// A [`Dialect`] for [Microsoft SQL Server](https://www.microsoft.com/en-us/sql-server/) -#[derive(Debug)] +#[derive(Debug, Default)] pub struct MsSqlDialect {} impl Dialect for MsSqlDialect { diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 81aa9d445a..67c8d7f648 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -35,7 +35,7 @@ const RESERVED_FOR_TABLE_ALIAS_MYSQL: &[Keyword] = &[ ]; /// A [`Dialect`] for [MySQL](https://www.mysql.com/) -#[derive(Debug)] +#[derive(Debug, Default)] pub struct MySqlDialect {} impl Dialect for MySqlDialect { diff --git a/src/dialect/oracle.rs b/src/dialect/oracle.rs index 54c2ace5fb..3f01240c7d 100644 --- a/src/dialect/oracle.rs +++ b/src/dialect/oracle.rs @@ -25,7 
+25,7 @@ use crate::{ use super::{Dialect, Precedence}; /// A [`Dialect`] for [Oracle Databases](https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/index.html) -#[derive(Debug)] +#[derive(Debug, Default)] pub struct OracleDialect; impl Dialect for OracleDialect { diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 7c9e7db86c..1924a5e313 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -34,7 +34,7 @@ use crate::parser::{Parser, ParserError}; use crate::tokenizer::Token; /// A [`Dialect`] for [PostgreSQL](https://www.postgresql.org/) -#[derive(Debug)] +#[derive(Debug, Default)] pub struct PostgreSqlDialect {} const PERIOD_PREC: u8 = 200; diff --git a/src/dialect/redshift.rs b/src/dialect/redshift.rs index 43c0646ce3..9c03148790 100644 --- a/src/dialect/redshift.rs +++ b/src/dialect/redshift.rs @@ -22,7 +22,7 @@ use core::str::Chars; use super::PostgreSqlDialect; /// A [`Dialect`] for [RedShift](https://aws.amazon.com/redshift/) -#[derive(Debug)] +#[derive(Debug, Default)] pub struct RedshiftSqlDialect {} // In most cases the redshift dialect is identical to [`PostgresSqlDialect`]. diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs index ba4cb6173a..7d1c935f16 100644 --- a/src/dialect/sqlite.rs +++ b/src/dialect/sqlite.rs @@ -30,7 +30,7 @@ use crate::parser::{Parser, ParserError}; /// [`CREATE TABLE`](https://sqlite.org/lang_createtable.html) statement with no /// type specified, as in `CREATE TABLE t1 (a)`. In the AST, these columns will /// have the data type [`Unspecified`](crate::ast::DataType::Unspecified). -#[derive(Debug)] +#[derive(Debug, Default)] pub struct SQLiteDialect {} impl Dialect for SQLiteDialect { diff --git a/src/lib.rs b/src/lib.rs index f5d23a21fc..e68d7f93eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -170,6 +170,9 @@ pub mod ast; #[macro_use] /// Submodules for SQL dialects. 
pub mod dialect; + +#[cfg(feature = "derive-dialect")] +pub use dialect::derive_dialect; mod display_utils; pub mod keywords; pub mod parser; diff --git a/tests/sqlparser_derive_dialect.rs b/tests/sqlparser_derive_dialect.rs new file mode 100644 index 0000000000..d60fa1e11d --- /dev/null +++ b/tests/sqlparser_derive_dialect.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for the `derive_dialect!` macro. 
+ +use sqlparser::derive_dialect; +use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect, PostgreSqlDialect}; +use sqlparser::parser::Parser; + +#[test] +fn test_method_overrides() { + derive_dialect!(EnhancedGenericDialect, GenericDialect, overrides = { + supports_order_by_all = true, + supports_triple_quoted_string = true, + }); + let dialect = EnhancedGenericDialect::new(); + + // Overridden methods + assert!(dialect.supports_order_by_all()); + assert!(dialect.supports_triple_quoted_string()); + + // Non-overridden retains base behavior + assert!(!dialect.supports_factorial_operator()); + + // Parsing works with the overrides + let result = Parser::new(&dialect) + .try_with_sql("SELECT '''value''' FROM t ORDER BY ALL") + .unwrap() + .parse_statements(); + + assert!(result.is_ok()); +} + +#[test] +fn test_preserve_type_id() { + // Check the override works and the parser recognizes it as the base type + derive_dialect!( + PreservedTypeDialect, + GenericDialect, + preserve_type_id = true, + overrides = { supports_order_by_all = true } + ); + let dialect = PreservedTypeDialect::new(); + let d: &dyn Dialect = &dialect; + + assert!(dialect.supports_order_by_all()); + assert!(d.is::()); +} + +#[test] +fn test_different_base_dialects() { + derive_dialect!( + EnhancedMySqlDialect, + MySqlDialect, + overrides = { supports_order_by_all = true } + ); + derive_dialect!(UniquePostgreSqlDialect, PostgreSqlDialect); + + let pg = UniquePostgreSqlDialect::new(); + let mysql = EnhancedMySqlDialect::new(); + + // Inherit different base behaviors + assert!(pg.supports_filter_during_aggregation()); // PostgreSQL feature + assert!(mysql.supports_string_literal_backslash_escape()); // MySQL feature + assert!(mysql.supports_order_by_all()); // Override + + // Each has unique TypeId + let pg_ref: &dyn Dialect = &pg; + let mysql_ref: &dyn Dialect = &mysql; + assert!(pg_ref.is::()); + assert!(!pg_ref.is::()); + assert!(mysql_ref.is::()); +} + +#[test] +fn 
test_identifier_quote_style_overrides() { + derive_dialect!( + BacktickGenericDialect, + GenericDialect, + overrides = { identifier_quote_style = '`' } + ); + derive_dialect!( + AnotherBacktickDialect, + GenericDialect, + overrides = { identifier_quote_style = '[' } + ); + derive_dialect!( + QuotelessPostgreSqlDialect, + PostgreSqlDialect, + preserve_type_id = true, + overrides = { identifier_quote_style = None } + ); + + // Char literal (auto-wrapped in Some) + assert_eq!( + BacktickGenericDialect::new().identifier_quote_style("x"), + Some('`') + ); + // Another char literal + assert_eq!( + AnotherBacktickDialect::new().identifier_quote_style("x"), + Some('[') + ); + // None (overrides PostgreSQL's default '"') + assert_eq!( + QuotelessPostgreSqlDialect::new().identifier_quote_style("x"), + None + ); +} From ee3378e917e34834a8b37c48987d3b7d60e03b7f Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 23 Jan 2026 15:14:36 +0400 Subject: [PATCH 2/2] Resolved some whitespace lint --- src/dialect/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 56acf4ef3f..215ed8e4e8 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -55,7 +55,7 @@ pub use self::sqlite::SQLiteDialect; /// Macro for streamlining the creation of derived `Dialect` objects. /// The generated struct includes `new()` and `default()` constructors. /// Requires the `derive-dialect` feature. -/// +/// /// # Syntax /// /// ```text