|
18 | 18 | use proc_macro2::TokenStream; |
19 | 19 | use quote::{format_ident, quote, quote_spanned, ToTokens}; |
20 | 20 | use syn::spanned::Spanned; |
21 | | -use syn::{ |
22 | | - parse::{Parse, ParseStream}, |
23 | | - parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, |
24 | | - Ident, Index, LitStr, Meta, Token, |
25 | | -}; |
| 21 | +use syn::{parse::{Parse, ParseStream}, parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token, Type, TypePath}; |
| 22 | +use syn::{Path, PathArguments}; |
26 | 23 |
|
27 | 24 | /// Implementation of `[#derive(Visit)]` |
28 | 25 | #[proc_macro_derive(VisitMut, attributes(visit))] |
@@ -182,9 +179,21 @@ fn visit_children( |
182 | 179 | Fields::Named(fields) => { |
183 | 180 | let recurse = fields.named.iter().map(|f| { |
184 | 181 | let name = &f.ident; |
| 182 | + let is_option = is_option(&f.ty); |
185 | 183 | let attributes = Attributes::parse(&f.attrs); |
186 | | - let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); |
187 | | - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit) |
| 184 | + if is_option && attributes.with.is_some() { |
| 185 | + let (pre_visit, post_visit) = attributes.visit(quote!(value)); |
| 186 | + quote_spanned!(f.span() => |
| 187 | + if let Some(value) = &#modifier self.#name { |
| 188 | + #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit |
| 189 | + } |
| 190 | + ) |
| 191 | + } else { |
| 192 | + let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); |
| 193 | + quote_spanned!(f.span() => |
| 194 | + #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit |
| 195 | + ) |
| 196 | + } |
188 | 197 | }); |
189 | 198 | quote! { |
190 | 199 | #(#recurse)* |
@@ -256,3 +265,16 @@ fn visit_children( |
256 | 265 | Data::Union(_) => unimplemented!(), |
257 | 266 | } |
258 | 267 | } |
| 268 | + |
| 269 | +fn is_option(ty: &Type) -> bool { |
| 270 | + if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty { |
| 271 | + if let Some(segment) = segments.last() { |
| 272 | + if segment.ident == "Option" { |
| 273 | + if let PathArguments::AngleBracketed(args) = &segment.arguments { |
| 274 | + return args.args.len() == 1; |
| 275 | + } |
| 276 | + } |
| 277 | + } |
| 278 | + } |
| 279 | + false |
| 280 | +} |
0 commit comments