From caa08d30f240c07f2b6fd08c6ffb9ae28d187f09 Mon Sep 17 00:00:00 2001 From: Manos Pitsidianakis Date: Thu, 3 Jul 2025 16:33:43 +0300 Subject: rust/qemu-api-macros: use syn::Error directly Our MacroError type wraps syn::Error as a variant, and uses another variant for custom errors. Fortunately syn::Error can be used directly, avoiding extra code on our side, so change the proc macro crate to use it. Signed-off-by: Manos Pitsidianakis Link: https://lore.kernel.org/r/20250703-rust_macros-v1-1-b99f82febbbf@linaro.org Signed-off-by: Paolo Bonzini --- rust/qemu-api-macros/src/lib.rs | 86 ++++++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 43 deletions(-) (limited to 'rust/qemu-api-macros/src/lib.rs') diff --git a/rust/qemu-api-macros/src/lib.rs b/rust/qemu-api-macros/src/lib.rs index c18bb4e036..2cb79c799a 100644 --- a/rust/qemu-api-macros/src/lib.rs +++ b/rust/qemu-api-macros/src/lib.rs @@ -6,83 +6,79 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, Data, - DeriveInput, Field, Fields, FieldsUnnamed, Ident, Meta, Path, Token, Variant, + DeriveInput, Error, Field, Fields, FieldsUnnamed, Ident, Meta, Path, Token, Variant, }; - -mod utils; -use utils::MacroError; - mod bits; use bits::BitsConstInternal; fn get_fields<'a>( input: &'a DeriveInput, msg: &str, -) -> Result<&'a Punctuated, MacroError> { +) -> Result<&'a Punctuated, Error> { let Data::Struct(ref s) = &input.data else { - return Err(MacroError::Message( - format!("Struct required for {msg}"), + return Err(Error::new( input.ident.span(), + format!("Struct required for {msg}"), )); }; let Fields::Named(ref fs) = &s.fields else { - return Err(MacroError::Message( - format!("Named fields required for {msg}"), + return Err(Error::new( input.ident.span(), + format!("Named fields required for {msg}"), )); }; Ok(&fs.named) } -fn get_unnamed_field<'a>(input: &'a DeriveInput, msg: &str) -> Result<&'a Field, MacroError> { +fn get_unnamed_field<'a>(input: &'a DeriveInput, msg: &str) -> Result<&'a Field, Error> { let Data::Struct(ref s) = &input.data else { - return Err(MacroError::Message( - format!("Struct required for {msg}"), + return Err(Error::new( input.ident.span(), + format!("Struct required for {msg}"), )); }; let Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) = &s.fields else { - return Err(MacroError::Message( - format!("Tuple struct required for {msg}"), + return Err(Error::new( s.fields.span(), + format!("Tuple struct required for {msg}"), )); }; if unnamed.len() != 1 { - return Err(MacroError::Message( - format!("A single field is required for {msg}"), + return Err(Error::new( s.fields.span(), + format!("A single field is required for {msg}"), )); } Ok(&unnamed[0]) } -fn is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> { +fn is_c_repr(input: &DeriveInput, msg: &str) -> Result<(), Error> { let expected = parse_quote! { #[repr(C)] }; if input.attrs.iter().any(|attr| attr == &expected) { Ok(()) } else { - Err(MacroError::Message( - format!("#[repr(C)] required for {msg}"), + Err(Error::new( input.ident.span(), + format!("#[repr(C)] required for {msg}"), )) } } -fn is_transparent_repr(input: &DeriveInput, msg: &str) -> Result<(), MacroError> { +fn is_transparent_repr(input: &DeriveInput, msg: &str) -> Result<(), Error> { let expected = parse_quote! { #[repr(transparent)] }; if input.attrs.iter().any(|attr| attr == &expected) { Ok(()) } else { - Err(MacroError::Message( - format!("#[repr(transparent)] required for {msg}"), + Err(Error::new( input.ident.span(), + format!("#[repr(transparent)] required for {msg}"), )) } } -fn derive_object_or_error(input: DeriveInput) -> Result { +fn derive_object_or_error(input: DeriveInput) -> Result { is_c_repr(&input, "#[derive(Object)]")?; let name = &input.ident; @@ -103,12 +99,13 @@ fn derive_object_or_error(input: DeriveInput) -> Result TokenStream { let input = parse_macro_input!(input as DeriveInput); - let expanded = derive_object_or_error(input).unwrap_or_else(Into::into); - TokenStream::from(expanded) + derive_object_or_error(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } -fn derive_opaque_or_error(input: DeriveInput) -> Result { +fn derive_opaque_or_error(input: DeriveInput) -> Result { is_transparent_repr(&input, "#[derive(Wrapper)]")?; let name = &input.ident; @@ -149,13 +146,14 @@ fn derive_opaque_or_error(input: DeriveInput) -> Result TokenStream { let input = parse_macro_input!(input as DeriveInput); - let expanded = derive_opaque_or_error(input).unwrap_or_else(Into::into); - TokenStream::from(expanded) + derive_opaque_or_error(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } #[allow(non_snake_case)] -fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result { +fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result { let repr = input.attrs.iter().find(|attr| attr.path().is_ident("repr")); if let Some(repr) = repr { let nested = repr.parse_args_with(Punctuated::::parse_terminated)?; @@ -170,23 +168,23 @@ fn get_repr_uN(input: &DeriveInput, msg: &str) -> Result { } } - Err(MacroError::Message( - format!("#[repr(u8/u16/u32/u64) required for {msg}"), + Err(Error::new( input.ident.span(), + format!("#[repr(u8/u16/u32/u64) required for {msg}"), )) } -fn get_variants(input: &DeriveInput) -> Result<&Punctuated, MacroError> { +fn get_variants(input: &DeriveInput) -> Result<&Punctuated, Error> { let Data::Enum(ref e) = &input.data else { - return Err(MacroError::Message( - "Cannot derive TryInto for union or struct.".to_string(), + return Err(Error::new( input.ident.span(), + "Cannot derive TryInto for union or struct.", )); }; if let Some(v) = e.variants.iter().find(|v| v.fields != Fields::Unit) { - return Err(MacroError::Message( - "Cannot derive TryInto for enum with non-unit variants.".to_string(), + return Err(Error::new( v.fields.span(), + "Cannot derive TryInto for enum with non-unit variants.", )); } Ok(&e.variants) @@ -197,7 +195,7 @@ fn derive_tryinto_body( name: &Ident, variants: &Punctuated, repr: &Path, -) -> Result { +) -> Result { let discriminants: Vec<&Ident> = variants.iter().map(|f| &f.ident).collect(); Ok(quote! { @@ -210,7 +208,7 @@ fn derive_tryinto_body( } #[rustfmt::skip::macros(quote)] -fn derive_tryinto_or_error(input: DeriveInput) -> Result { +fn derive_tryinto_or_error(input: DeriveInput) -> Result { let repr = get_repr_uN(&input, "#[derive(TryInto)]")?; let name = &input.ident; let body = derive_tryinto_body(name, get_variants(&input)?, &repr)?; @@ -247,9 +245,10 @@ fn derive_tryinto_or_error(input: DeriveInput) -> Result TokenStream { let input = parse_macro_input!(input as DeriveInput); - let expanded = derive_tryinto_or_error(input).unwrap_or_else(Into::into); - TokenStream::from(expanded) + derive_tryinto_or_error(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } #[proc_macro] @@ -257,6 +256,7 @@ pub fn bits_const_internal(ts: TokenStream) -> TokenStream { let ts = proc_macro2::TokenStream::from(ts); let mut it = ts.into_iter(); - let expanded = BitsConstInternal::parse(&mut it).unwrap_or_else(Into::into); - TokenStream::from(expanded) + BitsConstInternal::parse(&mut it) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } -- cgit 1.4.1 From c3a08c8dcbe568d9e7f8a66d300a668bcb8673c0 Mon Sep 17 00:00:00 2001 From: Manos Pitsidianakis Date: Fri, 4 Jul 2025 13:26:57 +0300 Subject: rust/qemu-api-macros: normalize TryInto output Remove extraneous `;` and add missing trailing comma to TryInto derive macro to match rustfmt style. We will add a test in the followup commit and we would like the inlined output in the test body to be properly formatted as well. No functional changes intended. Signed-off-by: Manos Pitsidianakis Reviewed-by: Zhao Liu Link: https://lore.kernel.org/r/20250704-rust_add_derive_macro_unit_tests-v1-1-ebd47fa7f78f@linaro.org Signed-off-by: Paolo Bonzini --- rust/qemu-api-macros/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rust/qemu-api-macros/src/lib.rs') diff --git a/rust/qemu-api-macros/src/lib.rs b/rust/qemu-api-macros/src/lib.rs index 2cb79c799a..5bbf8c6127 100644 --- a/rust/qemu-api-macros/src/lib.rs +++ b/rust/qemu-api-macros/src/lib.rs @@ -199,7 +199,7 @@ fn derive_tryinto_body( let discriminants: Vec<&Ident> = variants.iter().map(|f| &f.ident).collect(); Ok(quote! { - #(const #discriminants: #repr = #name::#discriminants as #repr;)*; + #(const #discriminants: #repr = #name::#discriminants as #repr;)* match value { #(#discriminants => core::result::Result::Ok(#name::#discriminants),)* _ => core::result::Result::Err(value), @@ -227,7 +227,7 @@ fn derive_tryinto_or_error(input: DeriveInput) -> Result x, - Err(_) => panic!(#errmsg) + Err(_) => panic!(#errmsg), } } } -- cgit 1.4.1 From a721d9a9f3dd7bb9d6ed81ea1a11a1157755741c Mon Sep 17 00:00:00 2001 From: Manos Pitsidianakis Date: Fri, 4 Jul 2025 13:26:58 +0300 Subject: rust/qemu-api-macros: add unit tests Add unit tests to check Derive macro output for expected error messages, or for expected correct codegen output. Signed-off-by: Manos Pitsidianakis Reviewed-by: Zhao Liu Link: https://lore.kernel.org/r/20250704-rust_add_derive_macro_unit_tests-v1-2-ebd47fa7f78f@linaro.org [Remove usage of MacroError. - Paolo] Signed-off-by: Paolo Bonzini --- rust/qemu-api-macros/meson.build | 3 + rust/qemu-api-macros/src/lib.rs | 3 + rust/qemu-api-macros/src/tests.rs | 137 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+) create mode 100644 rust/qemu-api-macros/src/tests.rs (limited to 'rust/qemu-api-macros/src/lib.rs') diff --git a/rust/qemu-api-macros/meson.build b/rust/qemu-api-macros/meson.build index 8610ce1c84..2152bcb99b 100644 --- a/rust/qemu-api-macros/meson.build +++ b/rust/qemu-api-macros/meson.build @@ -17,3 +17,6 @@ _qemu_api_macros_rs = rust.proc_macro( qemu_api_macros = declare_dependency( link_with: _qemu_api_macros_rs, ) + +rust.test('rust-qemu-api-macros-tests', _qemu_api_macros_rs, + suite: ['unit', 'rust']) diff --git a/rust/qemu-api-macros/src/lib.rs b/rust/qemu-api-macros/src/lib.rs index 5bbf8c6127..b525d89c09 100644 --- a/rust/qemu-api-macros/src/lib.rs +++ b/rust/qemu-api-macros/src/lib.rs @@ -11,6 +11,9 @@ use syn::{ mod bits; use bits::BitsConstInternal; +#[cfg(test)] +mod tests; + fn get_fields<'a>( input: &'a DeriveInput, msg: &str, diff --git a/rust/qemu-api-macros/src/tests.rs b/rust/qemu-api-macros/src/tests.rs new file mode 100644 index 0000000000..d6dcd62fcf --- /dev/null +++ b/rust/qemu-api-macros/src/tests.rs @@ -0,0 +1,137 @@ +// Copyright 2025, Linaro Limited +// Author(s): Manos Pitsidianakis +// SPDX-License-Identifier: GPL-2.0-or-later + +use quote::quote; + +use super::*; + +macro_rules! derive_compile_fail { + ($derive_fn:ident, $input:expr, $error_msg:expr) => {{ + let input: proc_macro2::TokenStream = $input; + let error_msg: &str = $error_msg; + let derive_fn: fn(input: syn::DeriveInput) -> Result = + $derive_fn; + + let input: syn::DeriveInput = syn::parse2(input).unwrap(); + let result = derive_fn(input); + let err = result.unwrap_err().into_compile_error(); + assert_eq!( + err.to_string(), + quote! { ::core::compile_error! { #error_msg } }.to_string() + ); + }}; +} + +macro_rules! derive_compile { + ($derive_fn:ident, $input:expr, $($expected:tt)*) => {{ + let input: proc_macro2::TokenStream = $input; + let expected: proc_macro2::TokenStream = $($expected)*; + let derive_fn: fn(input: syn::DeriveInput) -> Result = + $derive_fn; + + let input: syn::DeriveInput = syn::parse2(input).unwrap(); + let result = derive_fn(input).unwrap(); + assert_eq!(result.to_string(), expected.to_string()); + }}; +} + +#[test] +fn test_derive_object() { + derive_compile_fail!( + derive_object_or_error, + quote! { + #[derive(Object)] + struct Foo { + _unused: [u8; 0], + } + }, + "#[repr(C)] required for #[derive(Object)]" + ); + derive_compile!( + derive_object_or_error, + quote! { + #[derive(Object)] + #[repr(C)] + struct Foo { + _unused: [u8; 0], + } + }, + quote! { + ::qemu_api::assert_field_type!( + Foo, + _unused, + ::qemu_api::qom::ParentField<::ParentType> + ); + ::qemu_api::module_init! { + MODULE_INIT_QOM => unsafe { + ::qemu_api::bindings::type_register_static(&::TYPE_INFO); + } + } + } + ); +} + +#[test] +fn test_derive_tryinto() { + derive_compile_fail!( + derive_tryinto_or_error, + quote! { + #[derive(TryInto)] + struct Foo { + _unused: [u8; 0], + } + }, + "#[repr(u8/u16/u32/u64) required for #[derive(TryInto)]" + ); + derive_compile!( + derive_tryinto_or_error, + quote! { + #[derive(TryInto)] + #[repr(u8)] + enum Foo { + First = 0, + Second, + } + }, + quote! { + impl Foo { + #[allow(dead_code)] + pub const fn into_bits(self) -> u8 { + self as u8 + } + + #[allow(dead_code)] + pub const fn from_bits(value: u8) -> Self { + match ({ + const First: u8 = Foo::First as u8; + const Second: u8 = Foo::Second as u8; + match value { + First => core::result::Result::Ok(Foo::First), + Second => core::result::Result::Ok(Foo::Second), + _ => core::result::Result::Err(value), + } + }) { + Ok(x) => x, + Err(_) => panic!("invalid value for Foo"), + } + } + } + + impl core::convert::TryFrom for Foo { + type Error = u8; + + #[allow(ambiguous_associated_items)] + fn try_from(value: u8) -> Result { + const First: u8 = Foo::First as u8; + const Second: u8 = Foo::Second as u8; + match value { + First => core::result::Result::Ok(Foo::First), + Second => core::result::Result::Ok(Foo::Second), + _ => core::result::Result::Err(value), + } + } + } + } + ); +} -- cgit 1.4.1