diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index 99a19fe4b8..a661a59610 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -1856,6 +1856,7 @@ def test_threaded_weak_value_dict_copy(self): # copying should not result in a crash. self.check_threaded_weak_dict_copy(weakref.WeakValueDictionary, False) + @unittest.skip("TODO: RUSTPYTHON; flaky test") def test_threaded_weak_value_dict_deepcopy(self): # Issue #35615: Weakref keys or values getting GC'ed during dict # copying should not result in a crash. diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index 4d1173e1b4..305f64b22e 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -77,10 +77,7 @@ impl PyBaseObject { } } PyComparisonOp::Ne => { - let cmp = zelf - .class() - .mro_find_map(|cls| cls.slots.richcompare.load()) - .unwrap(); + let cmp = zelf.class().slots.richcompare.load().unwrap(); let value = match cmp(zelf, other, PyComparisonOp::Eq, vm)? { Either::A(obj) => PyArithmeticValue::from_object(vm, obj) .map(|obj| obj.try_to_bool(vm)) diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 3395a05e3e..52ce3c6e30 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -95,6 +95,8 @@ impl PyType { .map(|x| x.iter_mro().cloned().collect()) .collect(); let mro = linearise_mro(mros)?; + slots.inherits(&mro.iter().map(|t| -> &PyType { &t }).collect::>()); + debug_assert!(slots.hash.load().is_some(), "{}", name); if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { slots.flags |= PyTypeFlags::HAS_DICT @@ -141,19 +143,6 @@ impl PyType { std::iter::once(self).chain(self.mro.iter().map(|cls| -> &PyType { cls })) } - pub(crate) fn mro_find_map(&self, f: F) -> Option - where - F: Fn(&Self) -> Option, - { - // the hot path will be primitive types which usually hit the result from itself. - // try std::intrinsics::likely once it is stablized - if let Some(r) = f(self) { - Some(r) - } else { - self.mro.iter().find_map(|cls| f(cls)) - } - } - // This is used for class initialisation where the vm is not yet available. pub fn set_str_attr>(&self, attr_name: &str, value: V) { self._set_str_attr(attr_name, value.into()) @@ -521,7 +510,7 @@ impl PyType { // updated when __slots__ are supported (toggling the flag off if // a class has __slots__ defined). let flags = PyTypeFlags::heap_type_flags() | PyTypeFlags::HAS_DICT; - let slots = PyTypeSlots::from_flags(flags); + let slots = PyTypeSlots::with_flags(flags); let typ = Self::new_verbose_ref(name.as_str(), base, bases, attributes, slots, metatype) .map_err(|e| vm.new_type_error(e))?; @@ -637,11 +626,8 @@ impl GetAttr for PyType { if let Some(ref attr) = mcl_attr { let attr_class = attr.class(); - if attr_class - .mro_find_map(|cls| cls.slots.descr_set.load()) - .is_some() - { - if let Some(descr_get) = attr_class.mro_find_map(|cls| cls.slots.descr_get.load()) { + if attr_class.slots.descr_set.load().is_some() { + if let Some(descr_get) = attr_class.slots.descr_get.load() { let mcl = mcl.into_owned().into(); return descr_get(attr.clone(), Some(zelf.to_owned().into()), Some(mcl), vm); } @@ -651,7 +637,7 @@ impl GetAttr for PyType { let zelf_attr = zelf.get_attr(name); if let Some(ref attr) = zelf_attr { - if let Some(descr_get) = attr.class().mro_find_map(|cls| cls.slots.descr_get.load()) { + if let Some(descr_get) = attr.class().slots.descr_get.load() { drop(mcl); return descr_get(attr.clone(), None, Some(zelf.to_owned().into()), vm); } @@ -680,7 +666,7 @@ impl SetAttr for PyType { vm: &VirtualMachine, ) -> PyResult<()> { if let Some(attr) = zelf.get_class_attr(attr_name.as_str()) { - let descr_set = attr.class().mro_find_map(|cls| cls.slots.descr_set.load()); + let descr_set = attr.class().slots.descr_set.load(); if let Some(descriptor) = descr_set { return descriptor(attr, zelf.to_owned().into(), value, vm); } @@ -719,7 +705,7 @@ impl Callable for PyType { return Ok(obj); } - if let Some(init_method) = obj.class().mro_find_map(|cls| cls.slots.init.load()) { + if let Some(init_method) = obj.class().slots.init.load() { init_method(obj.clone(), args, vm)?; } Ok(obj) @@ -757,15 +743,12 @@ fn subtype_set_dict(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) - let cls = obj.class().clone(); match find_base_dict_descr(&cls, vm) { Some(descr) => { - let descr_set = descr - .class() - .mro_find_map(|cls| cls.slots.descr_set.load()) - .ok_or_else(|| { - vm.new_type_error(format!( - "this __dict__ descriptor does not support '{}' objects", - cls.name() - )) - })?; + let descr_set = descr.class().slots.descr_set.load().ok_or_else(|| { + vm.new_type_error(format!( + "this __dict__ descriptor does not support '{}' objects", + cls.name() + )) + })?; descr_set(descr, obj, Some(value), vm) } None => { diff --git a/vm/src/function/protocol.rs b/vm/src/function/protocol.rs index 9d5b54a5a1..6889d22542 100644 --- a/vm/src/function/protocol.rs +++ b/vm/src/function/protocol.rs @@ -85,7 +85,7 @@ where let iterfn; { let cls = obj.class(); - iterfn = cls.mro_find_map(|x| x.slots.iter.load()); + iterfn = cls.slots.iter.load(); if iterfn.is_none() && !cls.has_attr("__getitem__") { return Err(vm.new_type_error(format!("'{}' object is not iterable", cls.name()))); } diff --git a/vm/src/macros.rs b/vm/src/macros.rs index eff8917495..a886dd51be 100644 --- a/vm/src/macros.rs +++ b/vm/src/macros.rs @@ -17,7 +17,7 @@ macro_rules! py_class { ( $ctx:expr, $class_name:expr, $class_base:expr, $flags:expr, { $($name:tt => $value:expr),* $(,)* }) => { { #[allow(unused_mut)] - let mut slots = $crate::types::PyTypeSlots::from_flags($crate::types::PyTypeFlags::DEFAULT | $flags); + let mut slots = $crate::types::PyTypeSlots::with_flags($crate::types::PyTypeFlags::DEFAULT | $flags); $($crate::py_class!(@extract_slots($ctx, &mut slots, $name, $value));)* let py_class = $ctx.new_class(None, $class_name, $class_base, slots); $($crate::py_class!(@extract_attrs($ctx, &py_class, $name, $value));)* diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index 19c7aa47ec..dc17d4dbca 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -762,7 +762,7 @@ impl PyObject { } // CPython-compatible drop implementation - if let Some(slot_del) = self.class().mro_find_map(|cls| cls.slots.del.load()) { + if let Some(slot_del) = self.class().slots.del.load() { call_slot_del(self, slot_del)?; } if let Some(wrl) = self.weak_ref_list() { @@ -1081,22 +1081,25 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { static_assertions::assert_eq_size!(MaybeUninit>, PyInner); static_assertions::assert_eq_align!(MaybeUninit>, PyInner); - let type_payload = PyType { + let object_payload = PyType { base: None, bases: vec![], mro: vec![], subclasses: PyRwLock::default(), attributes: PyRwLock::new(Default::default()), - slots: PyType::make_slots(), + slots: object::PyBaseObject::make_slots(), }; - let object_payload = PyType { + let mut type_slots = PyType::make_slots(); + type_slots.inherits(&[&object_payload]); + let type_payload = PyType { base: None, bases: vec![], mro: vec![], subclasses: PyRwLock::default(), attributes: PyRwLock::new(Default::default()), - slots: object::PyBaseObject::make_slots(), + slots: type_slots, }; + let type_type_ptr = Box::into_raw(Box::new(partially_init!( PyInner:: { ref_count: RefCount::new(), @@ -1149,13 +1152,15 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { } }; + let mut weakref_slots = PyWeak::make_slots(); + weakref_slots.inherits(&[&object_type]); let weakref_type = PyType { base: Some(object_type.clone()), bases: vec![object_type.clone()], mro: vec![object_type.clone()], subclasses: PyRwLock::default(), attributes: PyRwLock::default(), - slots: PyWeak::make_slots(), + slots: weakref_slots, }; let weakref_type = PyRef::new_ref(weakref_type, type_type.clone(), None); diff --git a/vm/src/protocol/buffer.rs b/vm/src/protocol/buffer.rs index 9f4c191be0..26426d2ec7 100644 --- a/vm/src/protocol/buffer.rs +++ b/vm/src/protocol/buffer.rs @@ -140,7 +140,7 @@ impl PyBuffer { impl TryFromBorrowedObject for PyBuffer { fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult { let cls = obj.class(); - if let Some(f) = cls.mro_find_map(|cls| cls.slots.as_buffer) { + if let Some(f) = cls.slots.as_buffer { return f(obj, vm); } Err(vm.new_type_error(format!( diff --git a/vm/src/protocol/iter.rs b/vm/src/protocol/iter.rs index deef0c37ea..eccb8a2075 100644 --- a/vm/src/protocol/iter.rs +++ b/vm/src/protocol/iter.rs @@ -16,9 +16,7 @@ where impl PyIter { pub fn check(obj: &PyObject) -> bool { - obj.class() - .mro_find_map(|x| x.slots.iternext.load()) - .is_some() + obj.class().slots.iternext.load().is_some() } } @@ -34,7 +32,9 @@ where self.0 .borrow() .class() - .mro_find_map(|x| x.slots.iternext.load()) + .slots + .iternext + .load() .ok_or_else(|| { vm.new_type_error(format!( "'{}' object is not an iterator", @@ -120,7 +120,7 @@ impl TryFromObject for PyIter { fn try_from_object(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult { let getiter = { let cls = iter_target.class(); - cls.mro_find_map(|x| x.slots.iter.load()) + cls.slots.iter.load() }; if let Some(getiter) = getiter { let iter = getiter(iter_target, vm)?; diff --git a/vm/src/protocol/mapping.rs b/vm/src/protocol/mapping.rs index f8ec7437d4..7ea7f0c63a 100644 --- a/vm/src/protocol/mapping.rs +++ b/vm/src/protocol/mapping.rs @@ -70,11 +70,7 @@ impl PyMapping<'_> { pub fn methods(&self, vm: &VirtualMachine) -> &PyMappingMethods { self.methods.get_or_init(|| { - if let Some(f) = self - .obj - .class() - .mro_find_map(|cls| cls.slots.as_mapping.load()) - { + if let Some(f) = self.obj.class().slots.as_mapping.load() { f(self.obj, vm) } else { PyMappingMethods::default() diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 5023eca11f..c56f067cd1 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -90,10 +90,7 @@ impl PyObject { #[inline] fn _get_attr(&self, attr_name: PyStrRef, vm: &VirtualMachine) -> PyResult { vm_trace!("object.__getattribute__: {:?} {:?}", obj, attr_name); - let getattro = self - .class() - .mro_find_map(|cls| cls.slots.getattro.load()) - .unwrap(); + let getattro = self.class().slots.getattro.load().unwrap(); getattro(self, attr_name.clone(), vm).map_err(|exc| { vm.set_attribute_error_context(&exc, self.to_owned(), attr_name); exc @@ -108,18 +105,17 @@ impl PyObject { ) -> PyResult<()> { let setattro = { let cls = self.class(); - cls.mro_find_map(|cls| cls.slots.setattro.load()) - .ok_or_else(|| { - let assign = attr_value.is_some(); - let has_getattr = cls.mro_find_map(|cls| cls.slots.getattro.load()).is_some(); - vm.new_type_error(format!( - "'{}' object has {} attributes ({} {})", - cls.name(), - if has_getattr { "only read-only" } else { "no" }, - if assign { "assign to" } else { "del" }, - attr_name - )) - })? + cls.slots.setattro.load().ok_or_else(|| { + let assign = attr_value.is_some(); + let has_getattr = cls.slots.getattro.load().is_some(); + vm.new_type_error(format!( + "'{}' object has {} attributes ({} {})", + cls.name(), + if has_getattr { "only read-only" } else { "no" }, + if assign { "assign to" } else { "del" }, + attr_name + )) + })? }; setattro(self, attr_name, attr_value, vm) } @@ -145,7 +141,7 @@ impl PyObject { vm_trace!("object.__setattr__({:?}, {}, {:?})", obj, attr_name, value); if let Some(attr) = self.get_class_attr(attr_name.as_str()) { - let descr_set = attr.class().mro_find_map(|cls| cls.slots.descr_set.load()); + let descr_set = attr.class().slots.descr_set.load(); if let Some(descriptor) = descr_set { return descriptor(attr, self.to_owned(), value, vm); } @@ -194,12 +190,9 @@ impl PyObject { let cls_attr = match obj_cls.get_attr(name) { Some(descr) => { let descr_cls = descr.class(); - let descr_get = descr_cls.mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = descr_cls.slots.descr_get.load(); if let Some(descr_get) = descr_get { - if descr_cls - .mro_find_map(|cls| cls.slots.descr_set.load()) - .is_some() - { + if descr_cls.slots.descr_set.load().is_some() { drop(descr_cls); let cls = obj_cls.into_owned().into(); return descr_get(descr, Some(self.to_owned()), Some(cls), vm).map(Some); @@ -254,10 +247,7 @@ impl PyObject { ) -> PyResult> { let swapped = op.swapped(); let call_cmp = |obj: &PyObject, other: &PyObject, op| { - let cmp = obj - .class() - .mro_find_map(|cls| cls.slots.richcompare.load()) - .unwrap(); + let cmp = obj.class().slots.richcompare.load().unwrap(); let r = match cmp(obj, other, op, vm)? { Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A), Either::B(arithmetic) => arithmetic.map(Either::B), @@ -496,10 +486,7 @@ impl PyObject { } pub fn hash(&self, vm: &VirtualMachine) -> PyResult { - let hash = self - .class() - .mro_find_map(|cls| cls.slots.hash.load()) - .unwrap(); // hash always exist + let hash = self.class().slots.hash.load().unwrap(); // hash always exist hash(self, vm) } diff --git a/vm/src/protocol/sequence.rs b/vm/src/protocol/sequence.rs index 94f3b1fa40..4696f2d87f 100644 --- a/vm/src/protocol/sequence.rs +++ b/vm/src/protocol/sequence.rs @@ -97,7 +97,7 @@ impl PySequence<'_> { self.methods.get_or_init(|| { let cls = self.obj.class(); if !cls.is(&vm.ctx.types.dict_type) { - if let Some(f) = cls.mro_find_map(|x| x.slots.as_sequence.load()) { + if let Some(f) = cls.slots.as_sequence.load() { return f(self.obj, vm); } } diff --git a/vm/src/sequence.rs b/vm/src/sequence.rs index 798db58336..93c6f236c9 100644 --- a/vm/src/sequence.rs +++ b/vm/src/sequence.rs @@ -111,9 +111,7 @@ pub trait MutObjectSequenceOp<'a> { F: FnMut(), { let needle_cls = needle.class(); - let needle_cmp = needle_cls - .mro_find_map(|cls| cls.slots.richcompare.load()) - .unwrap(); + let needle_cmp = needle_cls.slots.richcompare.load().unwrap(); let mut borrower = None; let mut i = range.start; @@ -146,9 +144,7 @@ pub trait MutObjectSequenceOp<'a> { !elem_cls.is(&needle_cls) && elem_cls.fast_issubclass(&needle_cls); let eq = if reverse_first { - let elem_cmp = elem_cls - .mro_find_map(|cls| cls.slots.richcompare.load()) - .unwrap(); + let elem_cmp = elem_cls.slots.richcompare.load().unwrap(); drop(elem_cls); fn cmp( @@ -195,9 +191,7 @@ pub trait MutObjectSequenceOp<'a> { obj.try_to_bool(vm)? } _ => { - let elem_cmp = elem_cls - .mro_find_map(|cls| cls.slots.richcompare.load()) - .unwrap(); + let elem_cmp = elem_cls.slots.richcompare.load().unwrap(); drop(elem_cls); fn cmp( diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 449b0fc52e..91007c8b30 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -78,12 +78,52 @@ pub struct PyTypeSlots { } impl PyTypeSlots { - pub fn from_flags(flags: PyTypeFlags) -> Self { + pub fn with_flags(flags: PyTypeFlags) -> Self { Self { flags, ..Default::default() } } + + pub fn inherits(&mut self, mro: &[&PyType]) { + macro_rules! inherit { + ($name:ident) => { + if self.$name.is_none() { + for ty in mro { + if let Some(func) = ty.slots.$name { + self.$name = Some(func); + break; + } + } + } + }; + ($name:ident, "atomic") => { + if self.$name.load().is_none() { + for ty in mro { + if let Some(func) = ty.slots.$name.load() { + self.$name.store(Some(func)); + break; + } + } + } + }; + } + inherit!(as_sequence, "atomic"); + inherit!(as_mapping, "atomic"); + inherit!(hash, "atomic"); + inherit!(call, "atomic"); + inherit!(getattro, "atomic"); + inherit!(setattro, "atomic"); + inherit!(as_buffer); + inherit!(richcompare, "atomic"); + inherit!(iter, "atomic"); + inherit!(iternext, "atomic"); + inherit!(descr_get, "atomic"); + inherit!(descr_set, "atomic"); + inherit!(init, "atomic"); + inherit!(new, "atomic"); + inherit!(del, "atomic"); + } } impl std::fmt::Debug for PyTypeSlots { diff --git a/vm/src/vm/method.rs b/vm/src/vm/method.rs index de8e287472..3f905c2d4d 100644 --- a/vm/src/vm/method.rs +++ b/vm/src/vm/method.rs @@ -21,7 +21,7 @@ pub enum PyMethod { impl PyMethod { pub fn get(obj: PyObjectRef, name: PyStrRef, vm: &VirtualMachine) -> PyResult { let cls = obj.class(); - let getattro = cls.mro_find_map(|cls| cls.slots.getattro.load()).unwrap(); + let getattro = cls.slots.getattro.load().unwrap(); if getattro as usize != PyBaseObject::getattro as usize { drop(cls); return obj.get_attr(name, vm).map(Self::Attribute); @@ -36,12 +36,9 @@ impl PyMethod { is_method = true; None } else { - let descr_get = descr_cls.mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = descr_cls.slots.descr_get.load(); if let Some(descr_get) = descr_get { - if descr_cls - .mro_find_map(|cls| cls.slots.descr_set.load()) - .is_some() - { + if descr_cls.slots.descr_set.load().is_some() { drop(descr_cls); let cls = cls.into_owned().into(); return descr_get(descr, Some(obj), Some(cls), vm).map(Self::Attribute); diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 87e4ac22d8..300d4747dc 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -579,9 +579,7 @@ impl VirtualMachine { } pub fn is_callable(&self, obj: &PyObject) -> bool { - obj.class() - .mro_find_map(|cls| cls.slots.call.load()) - .is_some() + obj.class().slots.call.load().is_some() } #[inline] diff --git a/vm/src/vm/vm_object.rs b/vm/src/vm/vm_object.rs index 4c98a4c08e..9a3d6125a6 100644 --- a/vm/src/vm/vm_object.rs +++ b/vm/src/vm/vm_object.rs @@ -91,7 +91,7 @@ impl VirtualMachine { obj: Option, cls: Option, ) -> Result { - let descr_get = descr.class().mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = descr.class().slots.descr_get.load(); match descr_get { Some(descr_get) => Ok(descr_get(descr, obj, cls, self)), None => Err(descr), @@ -164,7 +164,7 @@ impl VirtualMachine { fn _invoke(&self, callable: &PyObject, args: FuncArgs) -> PyResult { vm_trace!("Invoke: {:?} {:?}", callable, args); - let slot_call = callable.class().mro_find_map(|cls| cls.slots.call.load()); + let slot_call = callable.class().slots.call.load(); match slot_call { Some(slot_call) => { self.trace_event(TraceEvent::Call)?;