diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/ctypes/base.rs index 29f3b9da2a..6fdb79e11e 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ b/crates/vm/src/stdlib/ctypes/base.rs @@ -3,8 +3,9 @@ use crate::builtins::PyType; use crate::builtins::{PyBytes, PyFloat, PyInt, PyNone, PyStr, PyTypeRef}; use crate::convert::ToPyObject; use crate::function::{Either, OptionalArg}; +use crate::protocol::PyNumberMethods; use crate::stdlib::ctypes::_ctypes::new_simple_type; -use crate::types::Constructor; +use crate::types::{AsNumber, Constructor}; use crate::{AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine}; use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; @@ -158,9 +159,10 @@ pub struct PyCData { impl PyCData {} #[pyclass(module = "_ctypes", name = "PyCSimpleType", base = PyType)] +#[derive(Debug, PyPayload)] pub struct PyCSimpleType {} -#[pyclass(flags(BASETYPE))] +#[pyclass(flags(BASETYPE), with(AsNumber))] impl PyCSimpleType { #[allow(clippy::new_ret_no_self)] #[pymethod] @@ -186,6 +188,33 @@ impl PyCSimpleType { PyCSimpleType::from_param(cls, as_parameter, vm) } + + #[pymethod] + fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { + PyCSimple::repeat(cls, n, vm) + } +} + +impl AsNumber for PyCSimpleType { + fn as_number() -> &'static PyNumberMethods { + static AS_NUMBER: PyNumberMethods = PyNumberMethods { + multiply: Some(|a, b, vm| { + // a is a PyCSimpleType instance (type object like c_char) + // b is int (array size) + let cls = a + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected type".to_owned()))?; + let n = b + .try_index(vm)? + .as_bigint() + .to_isize() + .ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?; + PyCSimple::repeat(cls.to_owned(), n, vm) + }), + ..PyNumberMethods::NOT_IMPLEMENTED + }; + &AS_NUMBER + } } #[pyclass( @@ -215,8 +244,18 @@ impl Constructor for PyCSimple { let attributes = cls.get_attributes(); let _type_ = attributes .iter() - .find(|(k, _)| k.to_object().str(vm).unwrap().to_string() == *"_type_") - .unwrap() + .find(|(k, _)| { + k.to_object() + .str(vm) + .map(|s| s.to_string() == "_type_") + .unwrap_or(false) + }) + .ok_or_else(|| { + vm.new_type_error(format!( + "cannot create '{}' instances: no _type_ attribute", + cls.name() + )) + })? .1 .str(vm)? .to_string(); @@ -276,11 +315,6 @@ impl PyCSimple { } .to_pyobject(vm)) } - - #[pyclassmethod] - fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { - PyCSimple::repeat(cls, n, vm) - } } impl PyCSimple { diff --git a/crates/vm/src/stdlib/ctypes/pointer.rs b/crates/vm/src/stdlib/ctypes/pointer.rs index b60280c73c..3eb6f68e6d 100644 --- a/crates/vm/src/stdlib/ctypes/pointer.rs +++ b/crates/vm/src/stdlib/ctypes/pointer.rs @@ -1,8 +1,13 @@ +use crossbeam_utils::atomic::AtomicCell; +use num_traits::ToPrimitive; use rustpython_common::lock::PyRwLock; -use crate::builtins::PyType; +use crate::builtins::{PyType, PyTypeRef}; +use crate::convert::ToPyObject; +use crate::protocol::PyNumberMethods; use crate::stdlib::ctypes::PyCData; -use crate::{PyObjectRef, PyResult}; +use crate::types::AsNumber; +use crate::{PyObjectRef, PyResult, VirtualMachine}; #[pyclass(name = "PyCPointerType", base = PyType, module = "_ctypes")] #[derive(PyPayload, Debug)] @@ -11,8 +16,44 @@ pub struct PyCPointerType { pub(crate) inner: PyCPointer, } -#[pyclass] -impl PyCPointerType {} +#[pyclass(flags(IMMUTABLETYPE), with(AsNumber))] +impl PyCPointerType { + #[pymethod] + fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { + use super::array::{PyCArray, PyCArrayType}; + if n < 0 { + return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}"))); + } + Ok(PyCArrayType { + inner: PyCArray { + typ: PyRwLock::new(cls), + length: AtomicCell::new(n as usize), + value: PyRwLock::new(vm.ctx.none()), + }, + } + .to_pyobject(vm)) + } +} + +impl AsNumber for PyCPointerType { + fn as_number() -> &'static PyNumberMethods { + static AS_NUMBER: PyNumberMethods = PyNumberMethods { + multiply: Some(|a, b, vm| { + let cls = a + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected type".to_owned()))?; + let n = b + .try_index(vm)? + .as_bigint() + .to_isize() + .ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?; + PyCPointerType::__mul__(cls.to_owned(), n, vm) + }), + ..PyNumberMethods::NOT_IMPLEMENTED + }; + &AS_NUMBER + } +} #[pyclass( name = "_Pointer",