diff --git a/derive/Cargo.toml b/derive/Cargo.toml index c15d47ad58..91fdc150e9 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -8,6 +8,6 @@ edition = "2018" proc-macro = true [dependencies] -syn = "0.15.29" +syn = { version = "0.15.29", features = ["full"] } quote = "0.6.11" proc-macro2 = "0.4.27" diff --git a/derive/src/from_args.rs b/derive/src/from_args.rs new file mode 100644 index 0000000000..7845b28800 --- /dev/null +++ b/derive/src/from_args.rs @@ -0,0 +1,202 @@ +use super::rustpython_path_derive; +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::{Attribute, Data, DeriveInput, Expr, Field, Fields, Ident, Lit, Meta, NestedMeta}; + +/// The kind of the python parameter, this corresponds to the value of Parameter.kind +/// (https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind) +enum ParameterKind { + PositionalOnly, + PositionalOrKeyword, + KeywordOnly, +} + +impl ParameterKind { + fn from_ident(ident: &Ident) -> ParameterKind { + if ident == "positional_only" { + ParameterKind::PositionalOnly + } else if ident == "positional_or_keyword" { + ParameterKind::PositionalOrKeyword + } else if ident == "keyword_only" { + ParameterKind::KeywordOnly + } else { + panic!("Unrecognised attribute") + } + } +} + +struct ArgAttribute { + kind: ParameterKind, + default: Option<Expr>, + optional: bool, +} + +impl ArgAttribute { + fn from_attribute(attr: &Attribute) -> Option<ArgAttribute> { + if !attr.path.is_ident("pyarg") { + return None; + } + + match attr.parse_meta().unwrap() { + Meta::List(list) => { + let mut iter = list.nested.iter(); + let first_arg = iter.next().expect("at least one argument in pyarg list"); + let kind = match first_arg { + NestedMeta::Meta(Meta::Word(ident)) => ParameterKind::from_ident(ident), + _ => panic!("Bad syntax for first pyarg attribute argument"), + }; + + let mut attribute = ArgAttribute { + kind, + default: None, + optional: false, + }; + + while let Some(arg) = iter.next() { +
attribute.parse_argument(arg); + } + + assert!( + attribute.default.is_none() || !attribute.optional, + "Can't set both a default value and optional" + ); + + Some(attribute) + } + _ => panic!("Bad syntax for pyarg attribute"), + } + } + + fn parse_argument(&mut self, arg: &NestedMeta) { + match arg { + NestedMeta::Meta(Meta::Word(ident)) => { + if ident == "default" { + assert!(self.default.is_none(), "Default already set"); + let expr = syn::parse_str::<Expr>("Default::default()").unwrap(); + self.default = Some(expr); + } else if ident == "optional" { + self.optional = true; + } else { + panic!("Unrecognised pyarg attribute '{}'", ident); + } + } + NestedMeta::Meta(Meta::NameValue(name_value)) => { + if name_value.ident == "default" { + assert!(self.default.is_none(), "Default already set"); + + match name_value.lit { + Lit::Str(ref val) => { + let expr = val + .parse::<Expr>() + .expect("a valid expression for default argument"); + self.default = Some(expr); + } + _ => panic!("Expected string value for default argument"), + } + } else if name_value.ident == "optional" { + match name_value.lit { + Lit::Bool(ref val) => { + self.optional = val.value; + } + _ => panic!("Expected boolean value for optional argument"), + } + } else { + panic!("Unrecognised pyarg attribute '{}'", name_value.ident); + } + } + _ => panic!("Bad syntax for first pyarg attribute argument"), + }; + } +} + +fn generate_field(field: &Field) -> TokenStream2 { + let mut pyarg_attrs = field + .attrs + .iter() + .filter_map(ArgAttribute::from_attribute) + .collect::<Vec<_>>(); + let attr = if pyarg_attrs.is_empty() { + ArgAttribute { + kind: ParameterKind::PositionalOrKeyword, + default: None, + optional: false, + } + } else if pyarg_attrs.len() == 1 { + pyarg_attrs.remove(0) + } else { + panic!( + "Multiple pyarg attributes on field '{}'", + field.ident.as_ref().unwrap() + ); + }; + + let name = &field.ident; + let middle = quote! { + .map(|x| crate::pyobject::TryFromObject::try_from_object(vm, x)).transpose()?
+ }; + let ending = if let Some(default) = attr.default { + quote! { + .unwrap_or_else(|| #default) + } + } else if attr.optional { + quote! { + .map(crate::function::OptionalArg::Present) + .unwrap_or(crate::function::OptionalArg::Missing) + } + } else { + let err = match attr.kind { + ParameterKind::PositionalOnly | ParameterKind::PositionalOrKeyword => quote! { + crate::function::ArgumentError::TooFewArgs + }, + ParameterKind::KeywordOnly => quote! { + crate::function::ArgumentError::RequiredKeywordArgument(stringify!(#name)) + }, + }; + quote! { + .ok_or_else(|| #err)? + } + }; + + match attr.kind { + ParameterKind::PositionalOnly => { + quote! { + #name: args.take_positional()#middle#ending, + } + } + ParameterKind::PositionalOrKeyword => { + quote! { + #name: args.take_positional_keyword(stringify!(#name))#middle#ending, + } + } + ParameterKind::KeywordOnly => { + quote! { + #name: args.take_keyword(stringify!(#name))#middle#ending, + } + } + } +} + +pub fn impl_from_args(input: DeriveInput) -> TokenStream2 { + let rp_path = rustpython_path_derive(&input); + let fields = match input.data { + Data::Struct(ref data) => { + match data.fields { + Fields::Named(ref fields) => fields.named.iter().map(generate_field), + Fields::Unnamed(_) | Fields::Unit => unimplemented!(), // TODO: better error message + } + } + Data::Enum(_) | Data::Union(_) => unimplemented!(), // TODO: better error message + }; + + let name = &input.ident; + quote!
{ + impl #rp_path::function::FromArgs for #name { + fn from_args( + vm: &crate::vm::VirtualMachine, + args: &mut crate::function::PyFuncArgs + ) -> Result<Self, crate::function::ArgumentError> { + Ok(#name { #(#fields)* }) + } + } + } +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 8f15f0e876..adde078f03 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,215 +1,63 @@ extern crate proc_macro; use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{Attribute, Data, DeriveInput, Expr, Field, Fields, Ident, Lit, Meta, NestedMeta}; +use syn::{parse_macro_input, AttributeArgs, DeriveInput, Item}; -#[proc_macro_derive(FromArgs, attributes(pyarg))] -pub fn derive_from_args(input: TokenStream) -> TokenStream { - let ast: DeriveInput = syn::parse(input).unwrap(); +mod from_args; +mod pyclass; - let gen = impl_from_args(&ast); - gen.to_string().parse().unwrap() +fn rustpython_path(inside_vm: bool) -> syn::Path { + let path = if inside_vm { + quote!(crate) + } else { + quote!(::rustpython_vm) + }; + syn::parse2(path).unwrap() } -/// The kind of the python parameter, this corresponds to the value of Parameter.kind -/// (https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind) -enum ParameterKind { - PositionalOnly, - PositionalOrKeyword, - KeywordOnly, +/// Resolves the crate path for a derive input: `crate` when the item carries the `#[__inside_vm]` +/// attribute (the derive target is inside the `rustpython_vm` crate), `::rustpython_vm` otherwise.
+fn rustpython_path_derive(input: &DeriveInput) -> syn::Path { + rustpython_path( + input + .attrs + .iter() + .any(|attr| attr.path.is_ident("__inside_vm")), + ) } -impl ParameterKind { - fn from_ident(ident: &Ident) -> ParameterKind { - if ident == "positional_only" { - ParameterKind::PositionalOnly - } else if ident == "positional_or_keyword" { - ParameterKind::PositionalOrKeyword - } else if ident == "keyword_only" { - ParameterKind::KeywordOnly +fn rustpython_path_attr(attr: &AttributeArgs) -> syn::Path { + rustpython_path(attr.iter().any(|meta| { + if let syn::NestedMeta::Meta(meta) = meta { + if let syn::Meta::Word(ident) = meta { + ident == "__inside_vm" + } else { + false + } } else { - panic!("Unrecognised attribute") + false } - } + })) } -struct ArgAttribute { - kind: ParameterKind, - default: Option, - optional: bool, -} - -impl ArgAttribute { - fn from_attribute(attr: &Attribute) -> Option { - if !attr.path.is_ident("pyarg") { - return None; - } - - match attr.parse_meta().unwrap() { - Meta::List(list) => { - let mut iter = list.nested.iter(); - let first_arg = iter.next().expect("at least one argument in pyarg list"); - let kind = match first_arg { - NestedMeta::Meta(Meta::Word(ident)) => ParameterKind::from_ident(ident), - _ => panic!("Bad syntax for first pyarg attribute argument"), - }; - - let mut attribute = ArgAttribute { - kind, - default: None, - optional: false, - }; - - while let Some(arg) = iter.next() { - attribute.parse_argument(arg); - } - - assert!( - attribute.default.is_none() || !attribute.optional, - "Can't set both a default value and optional" - ); - - Some(attribute) - } - _ => panic!("Bad syntax for pyarg attribute"), - } - } - - fn parse_argument(&mut self, arg: &NestedMeta) { - match arg { - NestedMeta::Meta(Meta::Word(ident)) => { - if ident == "default" { - assert!(self.default.is_none(), "Default already set"); - let expr = syn::parse_str::("Default::default()").unwrap(); - self.default = Some(expr); - } else if ident == 
"optional" { - self.optional = true; - } else { - panic!("Unrecognised pyarg attribute '{}'", ident); - } - } - NestedMeta::Meta(Meta::NameValue(name_value)) => { - if name_value.ident == "default" { - assert!(self.default.is_none(), "Default already set"); +#[proc_macro_derive(FromArgs, attributes(__inside_vm, pyarg))] +pub fn derive_from_args(input: TokenStream) -> TokenStream { + let ast: DeriveInput = syn::parse(input).unwrap(); - match name_value.lit { - Lit::Str(ref val) => { - let expr = val - .parse::() - .expect("a valid expression for default argument"); - self.default = Some(expr); - } - _ => panic!("Expected string value for default argument"), - } - } else if name_value.ident == "optional" { - match name_value.lit { - Lit::Bool(ref val) => { - self.optional = val.value; - } - _ => panic!("Expected boolean value for optional argument"), - } - } else { - panic!("Unrecognised pyarg attribute '{}'", name_value.ident); - } - } - _ => panic!("Bad syntax for first pyarg attribute argument"), - }; - } + from_args::impl_from_args(ast).into() } -fn generate_field(field: &Field) -> TokenStream2 { - let mut pyarg_attrs = field - .attrs - .iter() - .filter_map(ArgAttribute::from_attribute) - .collect::>(); - let attr = if pyarg_attrs.is_empty() { - ArgAttribute { - kind: ParameterKind::PositionalOrKeyword, - default: None, - optional: false, - } - } else if pyarg_attrs.len() == 1 { - pyarg_attrs.remove(0) - } else { - panic!( - "Multiple pyarg attributes on field '{}'", - field.ident.as_ref().unwrap() - ); - }; - - let name = &field.ident; - let middle = quote! { - .map(|x| crate::pyobject::TryFromObject::try_from_object(vm, x)).transpose()? - }; - let ending = if let Some(default) = attr.default { - quote! { - .unwrap_or_else(|| #default) - } - } else if attr.optional { - quote! 
{ - .map(crate::function::OptionalArg::Present) - .unwrap_or(crate::function::OptionalArg::Missing) - } - } else { - let err = match attr.kind { - ParameterKind::PositionalOnly | ParameterKind::PositionalOrKeyword => quote! { - crate::function::ArgumentError::TooFewArgs - }, - ParameterKind::KeywordOnly => quote! { - crate::function::ArgumentError::RequiredKeywordArgument(tringify!(#name)) - }, - }; - quote! { - .ok_or_else(|| #err)? - } - }; - - match attr.kind { - ParameterKind::PositionalOnly => { - quote! { - #name: args.take_positional()#middle#ending, - } - } - ParameterKind::PositionalOrKeyword => { - quote! { - #name: args.take_positional_keyword(stringify!(#name))#middle#ending, - } - } - ParameterKind::KeywordOnly => { - quote! { - #name: args.take_keyword(stringify!(#name))#middle#ending, - } - } - } +#[proc_macro_attribute] +pub fn pyclass(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = parse_macro_input!(attr as AttributeArgs); + let item = parse_macro_input!(item as Item); + pyclass::impl_pyclass(attr, item).into() } -fn impl_from_args(input: &DeriveInput) -> TokenStream2 { - // FIXME: This references types using `crate` instead of `rustpython_vm` - // so that it can be used in the latter. How can we support both? - // Can use extern crate self as rustpython_vm; once in stable. - // https://github.com/rust-lang/rust/issues/56409 - let fields = match input.data { - Data::Struct(ref data) => { - match data.fields { - Fields::Named(ref fields) => fields.named.iter().map(generate_field), - Fields::Unnamed(_) | Fields::Unit => unimplemented!(), // TODO: better error message - } - } - Data::Enum(_) | Data::Union(_) => unimplemented!(), // TODO: better error message - }; - - let name = &input.ident; - quote! 
{ - impl crate::function::FromArgs for #name { - fn from_args( - vm: &crate::vm::VirtualMachine, - args: &mut crate::function::PyFuncArgs - ) -> Result<Self, crate::function::ArgumentError> { - Ok(#name { #(#fields)* }) - } - } - } +#[proc_macro_attribute] +pub fn pyimpl(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = parse_macro_input!(attr as AttributeArgs); + let item = parse_macro_input!(item as Item); + pyclass::impl_pyimpl(attr, item).into() } diff --git a/derive/src/pyclass.rs b/derive/src/pyclass.rs new file mode 100644 index 0000000000..ad4ebf66ea --- /dev/null +++ b/derive/src/pyclass.rs @@ -0,0 +1,196 @@ +use super::rustpython_path_attr; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::quote; +use syn::{Attribute, AttributeArgs, Ident, ImplItem, Item, Lit, Meta, MethodSig, NestedMeta}; + +enum MethodKind { + Method, + Property, +} + +impl MethodKind { + fn to_ctx_constructor_fn(&self) -> Ident { + let f = match self { + MethodKind::Method => "new_rustfunc", + MethodKind::Property => "new_property", + }; + Ident::new(f, Span::call_site()) + } +} + +struct Method { + fn_name: Ident, + py_name: String, + kind: MethodKind, +} + +impl Method { + fn from_syn(attrs: &mut Vec<Attribute>, sig: &MethodSig) -> Option<Method> { + let mut py_name = None; + let mut kind = MethodKind::Method; + let mut pymethod_to_remove = Vec::new(); + let metas = attrs + .iter() + .enumerate() + .filter_map(|(i, attr)| { + if attr.path.is_ident("pymethod") { + let meta = attr.parse_meta().expect("Invalid attribute"); + // remove #[pymethod] because there's no actual proc macro + // implementation for it + pymethod_to_remove.push(i); + match meta { + Meta::List(list) => Some(list), + Meta::Word(_) => None, + Meta::NameValue(_) => panic!( + "#[pymethod = ...]
attribute on a method should be a list, like \ + #[pymethod(...)]" + ), + } + } else { + None + } + }) + .flat_map(|attr| attr.nested); + for meta in metas { + if let NestedMeta::Meta(meta) = meta { + match meta { + Meta::NameValue(name_value) => { + if name_value.ident == "name" { + if let Lit::Str(s) = &name_value.lit { + py_name = Some(s.value()); + } else { + panic!("#[pymethod(name = ...)] must be a string"); + } + } + } + Meta::Word(ident) => { + if ident == "property" { + kind = MethodKind::Property + } + } + _ => {} + } + } + } + // if there are no #[pymethod]s, then it's not a method, so exclude it from + // the final result + if pymethod_to_remove.is_empty() { + return None; + } + for i in pymethod_to_remove { + attrs.remove(i); + } + let py_name = py_name.unwrap_or_else(|| sig.ident.to_string()); + Some(Method { + fn_name: sig.ident.clone(), + py_name, + kind, + }) + } +} + +pub fn impl_pyimpl(attr: AttributeArgs, item: Item) -> TokenStream2 { + let mut imp = if let Item::Impl(imp) = item { + imp + } else { + return quote!(#item); + }; + + let rp_path = rustpython_path_attr(&attr); + + let methods = imp + .items + .iter_mut() + .filter_map(|item| { + if let ImplItem::Method(meth) = item { + Method::from_syn(&mut meth.attrs, &meth.sig) + } else { + None + } + }) + .collect::<Vec<_>>(); + let ty = &imp.self_ty; + let methods = methods.iter().map( + |Method { + py_name, + fn_name, + kind, + }| { + let constructor_fn = kind.to_ctx_constructor_fn(); + quote! { + ctx.set_attr(class, #py_name, ctx.#constructor_fn(Self::#fn_name)); + } + }, + ); + + quote!
{ + #imp + impl #rp_path::pyobject::PyClassImpl for #ty { + fn impl_extend_class( + ctx: &#rp_path::pyobject::PyContext, + class: &#rp_path::obj::objtype::PyClassRef, + ) { + #(#methods)* + } + } + } +} + +pub fn impl_pyclass(attr: AttributeArgs, item: Item) -> TokenStream2 { + let (item, ident, attrs) = match item { + Item::Struct(struc) => (quote!(#struc), struc.ident, struc.attrs), + Item::Enum(enu) => (quote!(#enu), enu.ident, enu.attrs), + _ => panic!("#[pyclass] can only be on a struct or enum declaration"), + }; + + let rp_path = rustpython_path_attr(&attr); + + let mut class_name = None; + for attr in attr { + if let NestedMeta::Meta(meta) = attr { + if let Meta::NameValue(name_value) = meta { + if name_value.ident == "name" { + if let Lit::Str(s) = name_value.lit { + class_name = Some(s.value()); + } else { + panic!("#[pyclass(name = ...)] must be a string"); + } + } + } + } + } + let class_name = class_name.unwrap_or_else(|| ident.to_string()); + + let mut doc: Option<Vec<String>> = None; + for attr in attrs.iter() { + if attr.path.is_ident("doc") { + let meta = attr.parse_meta().expect("expected doc attr to be a meta"); + if let Meta::NameValue(name_value) = meta { + if let Lit::Str(s) = name_value.lit { + let val = s.value().trim().to_string(); + match doc { + Some(ref mut doc) => doc.push(val), + None => doc = Some(vec![val]), + } + } else { + panic!("expected #[doc = ...] to be a string") + } + } + } + } + let doc = match doc { + Some(doc) => { + let doc = doc.join("\n"); + quote!(Some(#doc)) + } + None => quote!(None), + }; + + quote!
{ + #item + impl #rp_path::pyobject::PyClassDef for #ident { + const NAME: &'static str = #class_name; + const DOC: Option<&'static str> = #doc; + } + } +} diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index ed656c0fbf..205d0c66e0 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -536,6 +536,7 @@ fn builtin_pow(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } #[derive(Debug, FromArgs)] +#[__inside_vm] pub struct PrintOptions { #[pyarg(keyword_only, default = "None")] sep: Option, diff --git a/vm/src/lib.rs b/vm/src/lib.rs index 3892fe2a45..4a300dceff 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -32,6 +32,8 @@ extern crate rustpython_parser; #[macro_use] extern crate rustpython_derive; +pub use rustpython_derive::*; + //extern crate eval; use eval::eval::*; // use py_code_object::{Function, NativeType, PyCodeObject}; diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 08175dc07f..b0eb0a8e05 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -373,6 +373,7 @@ impl PyIntRef { } #[derive(FromArgs)] +#[__inside_vm] struct IntOptions { #[pyarg(positional_only, optional = true)] val_options: OptionalArg, diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index aa62812a24..922306a9e3 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -10,8 +10,8 @@ use unicode_segmentation::UnicodeSegmentation; use crate::format::{FormatParseError, FormatPart, FormatString}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - IdProtocol, IntoPyObject, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, - TryFromObject, TryIntoRef, TypeProtocol, + IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, + PyValue, TryFromObject, TryIntoRef, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -20,6 +20,17 @@ use super::objsequence::PySliceableSequence; use super::objslice::PySlice; use super::objtype::{self, PyClassRef}; +/// str(object='') -> str +/// 
str(bytes_or_buffer[, encoding[, errors]]) -> str +/// +/// Create a new string object from the given object. If encoding or +/// errors is specified, then the object must expose a data buffer +/// that will be decoded using the given encoding and error handler. +/// Otherwise, returns the result of object.__str__() (if defined) +/// or repr(object). +/// encoding defaults to sys.getdefaultencoding(). +/// errors defaults to 'strict'. +#[pyclass(name = "str", __inside_vm)] #[derive(Clone, Debug)] pub struct PyString { // TODO: shouldn't be public @@ -48,7 +59,29 @@ impl TryIntoRef<PyString> for &str { } } +#[pyimpl(__inside_vm)] impl PyStringRef { + // TODO: should support the following formats: + // class str(object='') + // class str(object=b'', encoding='utf-8', errors='strict') + #[pymethod(name = "__new__")] + fn new( + cls: PyClassRef, + object: OptionalArg<PyObjectRef>, + vm: &VirtualMachine, + ) -> PyResult<PyStringRef> { + let string = match object { + OptionalArg::Present(ref input) => vm.to_str(input)?.into_object(), + OptionalArg::Missing => vm.new_str("".to_string()), + }; + if string.class().is(&cls) { + TryFromObject::try_from_object(vm, string) + } else { + let payload = string.payload::<PyString>().unwrap(); + payload.clone().into_ref_with_type(vm, cls) + } + } + #[pymethod(name = "__add__")] fn add(self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult<String> { if objtype::isinstance(&rhs, &vm.ctx.str_type()) { Ok(format!("{}{}", self.value, get_value(&rhs))) @@ -57,10 +90,12 @@ impl PyStringRef { } } + #[pymethod(name = "__bool__")] fn bool(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() } + #[pymethod(name = "__eq__")] fn eq(self, rhs: PyObjectRef, vm: &VirtualMachine) -> bool { if objtype::isinstance(&rhs, &vm.ctx.str_type()) { self.value == get_value(&rhs) @@ -69,14 +104,17 @@ impl PyStringRef { } } + #[pymethod(name = "__contains__")] fn contains(self, needle: PyStringRef, _vm: &VirtualMachine) -> bool { self.value.contains(&needle.value) } + #[pymethod(name = "__getitem__")] fn
getitem(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { subscript(vm, &self.value, needle) } + #[pymethod(name = "__gt__")] fn gt(self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { if objtype::isinstance(&rhs, &vm.ctx.str_type()) { Ok(self.value > get_value(&rhs)) @@ -85,6 +123,7 @@ impl PyStringRef { } } + #[pymethod(name = "__ge__")] fn ge(self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { if objtype::isinstance(&rhs, &vm.ctx.str_type()) { Ok(self.value >= get_value(&rhs)) @@ -93,6 +132,7 @@ impl PyStringRef { } } + #[pymethod(name = "__lt__")] fn lt(self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { if objtype::isinstance(&rhs, &vm.ctx.str_type()) { Ok(self.value < get_value(&rhs)) @@ -101,6 +141,7 @@ impl PyStringRef { } } + #[pymethod(name = "__le__")] fn le(self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { if objtype::isinstance(&rhs, &vm.ctx.str_type()) { Ok(self.value <= get_value(&rhs)) @@ -109,16 +150,19 @@ impl PyStringRef { } } + #[pymethod(name = "__hash__")] fn hash(self, _vm: &VirtualMachine) -> usize { let mut hasher = std::collections::hash_map::DefaultHasher::new(); self.value.hash(&mut hasher); hasher.finish() as usize } + #[pymethod(name = "__len__")] fn len(self, _vm: &VirtualMachine) -> usize { self.value.chars().count() } + #[pymethod(name = "__mul__")] fn mul(self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { if objtype::isinstance(&val, &vm.ctx.int_type()) { let value = &self.value; @@ -133,10 +177,12 @@ impl PyStringRef { } } + #[pymethod(name = "__str__")] fn str(self, _vm: &VirtualMachine) -> PyStringRef { self } + #[pymethod(name = "__repr__")] fn repr(self, _vm: &VirtualMachine) -> String { let value = &self.value; let quote_char = if count_char(value, '\'') > count_char(value, '"') { @@ -167,27 +213,32 @@ impl PyStringRef { formatted } + #[pymethod] fn lower(self, _vm: &VirtualMachine) -> String { self.value.to_lowercase() } // casefold is much more aggressive than lower + 
#[pymethod] fn casefold(self, _vm: &VirtualMachine) -> String { caseless::default_case_fold_str(&self.value) } + #[pymethod] fn upper(self, _vm: &VirtualMachine) -> String { self.value.to_uppercase() } + #[pymethod] fn capitalize(self, _vm: &VirtualMachine) -> String { let (first_part, lower_str) = self.value.split_at(1); format!("{}{}", first_part.to_uppercase(), lower_str) } + #[pymethod] fn split( self, - pattern: OptionalArg, + pattern: OptionalArg, num: OptionalArg, vm: &VirtualMachine, ) -> PyObjectRef { @@ -206,9 +257,10 @@ impl PyStringRef { vm.ctx.new_list(elements) } + #[pymethod] fn rsplit( self, - pattern: OptionalArg, + pattern: OptionalArg, num: OptionalArg, vm: &VirtualMachine, ) -> PyObjectRef { @@ -227,18 +279,22 @@ impl PyStringRef { vm.ctx.new_list(elements) } + #[pymethod] fn strip(self, _vm: &VirtualMachine) -> String { self.value.trim().to_string() } + #[pymethod] fn lstrip(self, _vm: &VirtualMachine) -> String { self.value.trim_start().to_string() } + #[pymethod] fn rstrip(self, _vm: &VirtualMachine) -> String { self.value.trim_end().to_string() } + #[pymethod] fn endswith( self, suffix: PyStringRef, @@ -253,6 +309,7 @@ impl PyStringRef { } } + #[pymethod] fn startswith( self, prefix: PyStringRef, @@ -267,14 +324,17 @@ impl PyStringRef { } } + #[pymethod] fn isalnum(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() && self.value.chars().all(char::is_alphanumeric) } + #[pymethod] fn isnumeric(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() && self.value.chars().all(char::is_numeric) } + #[pymethod] fn isdigit(self, _vm: &VirtualMachine) -> bool { // python's isdigit also checks if exponents are digits, these are the unicodes for exponents let valid_unicodes: [u16; 10] = [ @@ -291,6 +351,7 @@ impl PyStringRef { } } + #[pymethod] fn isdecimal(self, _vm: &VirtualMachine) -> bool { if self.value.is_empty() { false @@ -299,10 +360,41 @@ impl PyStringRef { } } + #[pymethod] + fn format(vm: &VirtualMachine, args: PyFuncArgs) 
-> PyResult { + if args.args.is_empty() { + return Err(vm.new_type_error( + "descriptor 'format' of 'str' object needs an argument".to_string(), + )); + } + + let zelf = &args.args[0]; + if !objtype::isinstance(&zelf, &vm.ctx.str_type()) { + let zelf_typ = zelf.class(); + let actual_type = vm.to_pystr(&zelf_typ)?; + return Err(vm.new_type_error(format!( + "descriptor 'format' requires a 'str' object but received a '{}'", + actual_type + ))); + } + let format_string_text = get_value(zelf); + match FormatString::from_str(format_string_text.as_str()) { + Ok(format_string) => perform_format(vm, &format_string, &args), + Err(err) => match err { + FormatParseError::UnmatchedBracket => { + Err(vm.new_value_error("expected '}' before end of string".to_string())) + } + _ => Err(vm.new_value_error("Unexpected error parsing format string".to_string())), + }, + } + } + + #[pymethod] fn title(self, _vm: &VirtualMachine) -> String { make_title(&self.value) } + #[pymethod] fn swapcase(self, _vm: &VirtualMachine) -> String { let mut swapped_str = String::with_capacity(self.value.len()); for c in self.value.chars() { @@ -318,14 +410,16 @@ impl PyStringRef { swapped_str } + #[pymethod] fn isalpha(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() && self.value.chars().all(char::is_alphanumeric) } + #[pymethod] fn replace( self, - old: Self, - new: Self, + old: PyStringRef, + new: PyStringRef, num: OptionalArg, _vm: &VirtualMachine, ) -> String { @@ -337,10 +431,12 @@ impl PyStringRef { // cpython's isspace ignores whitespace, including \t and \n, etc, unless the whole string is empty // which is why isspace is using is_ascii_whitespace. 
Same for isupper & islower + #[pymethod] fn isspace(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() && self.value.chars().all(|c| c.is_ascii_whitespace()) } + #[pymethod] fn isupper(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() && self @@ -350,6 +446,7 @@ impl PyStringRef { .all(char::is_uppercase) } + #[pymethod] fn islower(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() && self @@ -359,11 +456,13 @@ impl PyStringRef { .all(char::is_lowercase) } + #[pymethod] fn isascii(self, _vm: &VirtualMachine) -> bool { !self.value.is_empty() && self.value.chars().all(|c| c.is_ascii()) } // doesn't implement keep new line delimiter just yet + #[pymethod] fn splitlines(self, vm: &VirtualMachine) -> PyObjectRef { let elements = self .value @@ -373,6 +472,7 @@ impl PyStringRef { vm.ctx.new_list(elements) } + #[pymethod] fn join(self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { let mut joined = String::new(); @@ -387,9 +487,10 @@ impl PyStringRef { Ok(joined) } + #[pymethod] fn find( self, - sub: Self, + sub: PyStringRef, start: OptionalArg, end: OptionalArg, _vm: &VirtualMachine, @@ -405,9 +506,10 @@ impl PyStringRef { } } + #[pymethod] fn rfind( self, - sub: Self, + sub: PyStringRef, start: OptionalArg, end: OptionalArg, _vm: &VirtualMachine, @@ -423,9 +525,10 @@ impl PyStringRef { } } + #[pymethod] fn index( self, - sub: Self, + sub: PyStringRef, start: OptionalArg, end: OptionalArg, vm: &VirtualMachine, @@ -441,9 +544,10 @@ impl PyStringRef { } } + #[pymethod] fn rindex( self, - sub: Self, + sub: PyStringRef, start: OptionalArg, end: OptionalArg, vm: &VirtualMachine, @@ -459,6 +563,7 @@ impl PyStringRef { } } + #[pymethod] fn partition(self, sub: PyStringRef, vm: &VirtualMachine) -> PyObjectRef { let value = &self.value; let sub = &sub.value; @@ -477,6 +582,7 @@ impl PyStringRef { vm.ctx.new_tuple(new_tup) } + #[pymethod] fn rpartition(self, sub: PyStringRef, vm: &VirtualMachine) -> PyObjectRef { let value = &self.value; 
let sub = &sub.value; @@ -496,6 +602,7 @@ impl PyStringRef { vm.ctx.new_tuple(new_tup) } + #[pymethod] fn istitle(self, _vm: &VirtualMachine) -> bool { if self.value.is_empty() { false @@ -504,9 +611,10 @@ impl PyStringRef { } } + #[pymethod] fn count( self, - sub: Self, + sub: PyStringRef, start: OptionalArg, end: OptionalArg, _vm: &VirtualMachine, @@ -519,6 +627,7 @@ impl PyStringRef { } } + #[pymethod] fn zfill(self, len: usize, _vm: &VirtualMachine) -> String { let value = &self.value; if len <= value.len() { @@ -542,21 +651,39 @@ impl PyStringRef { } } - fn ljust(self, len: usize, rep: OptionalArg, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn ljust( + self, + len: usize, + rep: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { let value = &self.value; - let rep_char = PyStringRef::get_fill_char(&rep, vm)?; + let rep_char = Self::get_fill_char(&rep, vm)?; Ok(format!("{}{}", value, rep_char.repeat(len))) } - fn rjust(self, len: usize, rep: OptionalArg, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn rjust( + self, + len: usize, + rep: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { let value = &self.value; - let rep_char = PyStringRef::get_fill_char(&rep, vm)?; + let rep_char = Self::get_fill_char(&rep, vm)?; Ok(format!("{}{}", rep_char.repeat(len), value)) } - fn center(self, len: usize, rep: OptionalArg, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn center( + self, + len: usize, + rep: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { let value = &self.value; - let rep_char = PyStringRef::get_fill_char(&rep, vm)?; + let rep_char = Self::get_fill_char(&rep, vm)?; let left_buff: usize = (len - value.len()) / 2; let right_buff = len - value.len() - left_buff; Ok(format!( @@ -567,6 +694,7 @@ impl PyStringRef { )) } + #[pymethod] fn expandtabs(self, tab_stop: OptionalArg, _vm: &VirtualMachine) -> String { let tab_stop = tab_stop.into_option().unwrap_or(8 as usize); let mut expanded_str = String::new(); @@ -590,6 +718,7 @@ impl 
PyStringRef { expanded_str } + #[pymethod] fn isidentifier(self, _vm: &VirtualMachine) -> bool { let value = &self.value; // a string is not an identifier if it has whitespace or starts with a number @@ -620,78 +749,8 @@ impl IntoPyObject for String { } } -#[rustfmt::skip] // to avoid line splitting -pub fn init(context: &PyContext) { - let str_type = &context.str_type; - let str_doc = "str(object='') -> str\n\ - str(bytes_or_buffer[, encoding[, errors]]) -> str\n\ - \n\ - Create a new string object from the given object. If encoding or\n\ - errors is specified, then the object must expose a data buffer\n\ - that will be decoded using the given encoding and error handler.\n\ - Otherwise, returns the result of object.__str__() (if defined)\n\ - or repr(object).\n\ - encoding defaults to sys.getdefaultencoding().\n\ - errors defaults to 'strict'."; - - extend_class!(context, str_type, { - "__add__" => context.new_rustfunc(PyStringRef::add), - "__bool__" => context.new_rustfunc(PyStringRef::bool), - "__contains__" => context.new_rustfunc(PyStringRef::contains), - "__doc__" => context.new_str(str_doc.to_string()), - "__eq__" => context.new_rustfunc(PyStringRef::eq), - "__ge__" => context.new_rustfunc(PyStringRef::ge), - "__getitem__" => context.new_rustfunc(PyStringRef::getitem), - "__gt__" => context.new_rustfunc(PyStringRef::gt), - "__hash__" => context.new_rustfunc(PyStringRef::hash), - "__lt__" => context.new_rustfunc(PyStringRef::lt), - "__le__" => context.new_rustfunc(PyStringRef::le), - "__len__" => context.new_rustfunc(PyStringRef::len), - "__mul__" => context.new_rustfunc(PyStringRef::mul), - "__new__" => context.new_rustfunc(str_new), - "__repr__" => context.new_rustfunc(PyStringRef::repr), - "__str__" => context.new_rustfunc(PyStringRef::str), - "capitalize" => context.new_rustfunc(PyStringRef::capitalize), - "casefold" => context.new_rustfunc(PyStringRef::casefold), - "center" => context.new_rustfunc(PyStringRef::center), - "count" => 
context.new_rustfunc(PyStringRef::count), - "endswith" => context.new_rustfunc(PyStringRef::endswith), - "expandtabs" => context.new_rustfunc(PyStringRef::expandtabs), - "find" => context.new_rustfunc(PyStringRef::find), - "format" => context.new_rustfunc(str_format), - "index" => context.new_rustfunc(PyStringRef::index), - "isalnum" => context.new_rustfunc(PyStringRef::isalnum), - "isalpha" => context.new_rustfunc(PyStringRef::isalpha), - "isascii" => context.new_rustfunc(PyStringRef::isascii), - "isdecimal" => context.new_rustfunc(PyStringRef::isdecimal), - "isdigit" => context.new_rustfunc(PyStringRef::isdigit), - "isidentifier" => context.new_rustfunc(PyStringRef::isidentifier), - "islower" => context.new_rustfunc(PyStringRef::islower), - "isnumeric" => context.new_rustfunc(PyStringRef::isnumeric), - "isspace" => context.new_rustfunc(PyStringRef::isspace), - "isupper" => context.new_rustfunc(PyStringRef::isupper), - "istitle" => context.new_rustfunc(PyStringRef::istitle), - "join" => context.new_rustfunc(PyStringRef::join), - "lower" => context.new_rustfunc(PyStringRef::lower), - "ljust" => context.new_rustfunc(PyStringRef::ljust), - "lstrip" => context.new_rustfunc(PyStringRef::lstrip), - "partition" => context.new_rustfunc(PyStringRef::partition), - "replace" => context.new_rustfunc(PyStringRef::replace), - "rfind" => context.new_rustfunc(PyStringRef::rfind), - "rindex" => context.new_rustfunc(PyStringRef::rindex), - "rjust" => context.new_rustfunc(PyStringRef::rjust), - "rpartition" => context.new_rustfunc(PyStringRef::rpartition), - "rsplit" => context.new_rustfunc(PyStringRef::rsplit), - "rstrip" => context.new_rustfunc(PyStringRef::rstrip), - "split" => context.new_rustfunc(PyStringRef::split), - "splitlines" => context.new_rustfunc(PyStringRef::splitlines), - "startswith" => context.new_rustfunc(PyStringRef::startswith), - "strip" => context.new_rustfunc(PyStringRef::strip), - "swapcase" => context.new_rustfunc(PyStringRef::swapcase), - "title" => 
context.new_rustfunc(PyStringRef::title), - "upper" => context.new_rustfunc(PyStringRef::upper), - "zfill" => context.new_rustfunc(PyStringRef::zfill), - }); +pub fn init(ctx: &PyContext) { + PyStringRef::extend_class(ctx, &ctx.str_type); } pub fn get_value(obj: &PyObjectRef) -> String { @@ -706,34 +765,6 @@ fn count_char(s: &str, c: char) -> usize { s.chars().filter(|x| *x == c).count() } -fn str_format(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - if args.args.is_empty() { - return Err( - vm.new_type_error("descriptor 'format' of 'str' object needs an argument".to_string()) - ); - } - - let zelf = &args.args[0]; - if !objtype::isinstance(&zelf, &vm.ctx.str_type()) { - let zelf_typ = zelf.class(); - let actual_type = vm.to_pystr(&zelf_typ)?; - return Err(vm.new_type_error(format!( - "descriptor 'format' requires a 'str' object but received a '{}'", - actual_type - ))); - } - let format_string_text = get_value(zelf); - match FormatString::from_str(format_string_text.as_str()) { - Ok(format_string) => perform_format(vm, &format_string, &args), - Err(err) => match err { - FormatParseError::UnmatchedBracket => { - Err(vm.new_value_error("expected '}' before end of string".to_string())) - } - _ => Err(vm.new_value_error("Unexpected error parsing format string".to_string())), - }, - } -} - fn call_object_format(vm: &VirtualMachine, argument: PyObjectRef, format_spec: &str) -> PyResult { let returned_type = vm.ctx.new_str(format_spec.to_string()); let result = vm.call_method(&argument, "__format__", vec![returned_type])?; @@ -797,26 +828,6 @@ fn perform_format( Ok(vm.ctx.new_str(final_string)) } -// TODO: should with following format -// class str(object='') -// class str(object=b'', encoding='utf-8', errors='strict') -fn str_new( - cls: PyClassRef, - object: OptionalArg, - vm: &VirtualMachine, -) -> PyResult { - let string = match object { - OptionalArg::Present(ref input) => vm.to_str(input)?.into_object(), - OptionalArg::Missing => vm.new_str("".to_string()), 
- }; - if string.class().is(&cls) { - TryFromObject::try_from_object(vm, string) - } else { - let payload = string.payload::().unwrap(); - payload.clone().into_ref_with_type(vm, cls) - } -} - impl PySliceableSequence for String { fn do_slice(&self, range: Range) -> Self { to_graphemes(self) diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index aa3348be59..801522ea07 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1259,6 +1259,36 @@ where } } +pub trait PyClassDef { + const NAME: &'static str; + const DOC: Option<&'static str> = None; +} + +impl PyClassDef for PyRef +where + T: PyClassDef, +{ + const NAME: &'static str = T::NAME; + const DOC: Option<&'static str> = T::DOC; +} + +pub trait PyClassImpl: PyClassDef { + fn impl_extend_class(ctx: &PyContext, class: &PyClassRef); + + fn extend_class(ctx: &PyContext, class: &PyClassRef) { + Self::impl_extend_class(ctx, class); + if let Some(doc) = Self::DOC { + ctx.set_attr(class, "__doc__", ctx.new_str(doc.into())); + } + } + + fn make_class(ctx: &PyContext) -> PyClassRef { + let py_class = ctx.new_class(Self::NAME, ctx.object()); + Self::extend_class(ctx, &py_class); + py_class + } +} + #[cfg(test)] mod tests { use super::*;