From ed8c1e1a955725093c1ab04ea6b4e2c9c52e94ac Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 29 Mar 2019 16:39:52 +0300 Subject: [PATCH 01/11] Implement set and frozenset with PySetInner --- tests/snippets/set.py | 7 + vm/src/obj/objset.rs | 1100 +++++++++++++++++++++++------------------ 2 files changed, 617 insertions(+), 490 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index 9e49dff89e..0117163350 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -205,6 +205,13 @@ def __hash__(self): assert_raises(TypeError, lambda: frozenset([[]])) +a = frozenset([1,2,3]) +b = set() +for e in a: + assert e == 1 or e == 2 or e == 3 + b.add(e) +assert a == b + # set and frozen set assert frozenset([1,2,3]).union(set([4,5])) == frozenset([1,2,3,4,5]) assert set([1,2,3]).union(frozenset([4,5])) == set([1,2,3,4,5]) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 34d487fbb9..eaa0cd09ca 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -7,7 +7,7 @@ use std::collections::{hash_map::DefaultHasher, HashMap}; use std::fmt; use std::hash::{Hash, Hasher}; -use crate::function::{OptionalArg, PyFuncArgs}; +use crate::function::OptionalArg; use crate::pyobject::{PyContext, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol}; use crate::vm::{ReprGuard, VirtualMachine}; @@ -20,13 +20,13 @@ use super::objtype::PyClassRef; #[derive(Default)] pub struct PySet { - elements: RefCell>, + inner: RefCell, } pub type PySetRef = PyRef; #[derive(Default)] pub struct PyFrozenSet { - elements: HashMap, + inner: PySetInner, } pub type PyFrozenSetRef = PyRef; @@ -56,536 +56,655 @@ impl PyValue for PyFrozenSet { } } -pub fn get_elements(obj: &PyObjectRef) -> HashMap { - if let Some(set) = obj.payload::() { - return set.elements.borrow().clone(); - } else if let Some(frozenset) = obj.payload::() { - return frozenset.elements.clone(); - } - panic!("Not frozenset or set"); +#[derive(Default, Clone)] +struct PySetInner { + elements: HashMap, } -fn validate_set_or_frozenset(vm: &VirtualMachine, cls: PyClassRef) -> PyResult<()> { - if !(objtype::issubclass(&cls, &vm.ctx.set_type()) - || objtype::issubclass(&cls, &vm.ctx.frozenset_type())) - { - return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", cls))); +impl PySetInner { + fn len(&self) -> usize { + self.elements.len() } - Ok(()) -} - -fn create_set( - vm: &VirtualMachine, - elements: HashMap, - cls: PyClassRef, -) -> PyResult { - if objtype::issubclass(&cls, &vm.ctx.set_type()) { - Ok(PyObject::new( - PySet { - elements: RefCell::new(elements), - }, - PySet::class(vm), - None, - )) - } else if objtype::issubclass(&cls, &vm.ctx.frozenset_type()) { - Ok(PyObject::new( - PyFrozenSet { elements: elements }, - PyFrozenSet::class(vm), - None, - )) - } else { - Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", cls))) + fn copy(&self) -> PySetInner { + PySetInner { + elements: self.elements.clone(), + } } -} - -fn perform_action_with_hash( - vm: &VirtualMachine, - elements: &mut HashMap, - item: &PyObjectRef, - f: &Fn(&VirtualMachine, &mut HashMap, u64, &PyObjectRef) -> PyResult, -) -> PyResult { - let hash: PyObjectRef = vm.call_method(item, "__hash__", vec![])?; - let hash_value = objint::get_value(&hash); - let mut hasher = DefaultHasher::new(); - hash_value.hash(&mut hasher); - let key = hasher.finish(); - f(vm, elements, key, item) -} + fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + for element in self.elements.iter() { + let value = vm._eq(needle.clone(), element.1.clone())?; + if objbool::get_value(&value) { + return Ok(vm.new_bool(true)); + } + } + Ok(vm.new_bool(false)) + } -fn insert_into_set( - vm: &VirtualMachine, - elements: &mut HashMap, - item: &PyObjectRef, -) -> PyResult { - fn insert( + fn _compare_inner( + &self, + other: &PySetInner, + size_func: &Fn(usize, usize) -> bool, + swap: bool, vm: &VirtualMachine, - elements: &mut HashMap, - key: u64, - value: &PyObjectRef, ) -> PyResult { - elements.insert(key, value.clone()); - Ok(vm.get_none()) + let get_zelf = |swap: bool| -> &PySetInner { + if swap { + other + } else { + self + } + }; + let get_other = |swap: bool| -> &PySetInner { + if swap { + self + } else { + other + } + }; + + if size_func(get_zelf(swap).len(), get_other(swap).len()) { + return Ok(vm.new_bool(false)); + } + for element in get_other(swap).elements.iter() { + let value = get_zelf(swap).contains(element.1.clone(), vm)?; + if !objbool::get_value(&value) { + return Ok(vm.new_bool(false)); + } + } + Ok(vm.new_bool(true)) } - perform_action_with_hash(vm, elements, item, &insert) -} -fn set_add(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - trace!("set.add called with: {:?}", args); - arg_check!( - vm, - args, - required = [(zelf, Some(vm.ctx.set_type())), (item, None)] - ); - match zelf.payload::() { - Some(set) => insert_into_set(vm, &mut set.elements.borrow_mut(), item), - _ => Err(vm.new_type_error("set.add is called with no item".to_string())), + fn eq(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { + self._compare_inner( + other, + &|zelf: usize, other: usize| -> bool { zelf != other }, + false, + vm, + ) } -} -fn set_remove(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - trace!("set.remove called with: {:?}", args); - arg_check!( - vm, - args, - required = [(s, Some(vm.ctx.set_type())), (item, None)] - ); - match s.payload::() { - Some(set) => { - fn remove( - vm: &VirtualMachine, - elements: &mut HashMap, - key: u64, - value: &PyObjectRef, - ) -> PyResult { - match elements.remove(&key) { - None => { - let item_str = format!("{:?}", value); - Err(vm.new_key_error(item_str)) - } - Some(_) => Ok(vm.get_none()), - } - } - perform_action_with_hash(vm, &mut set.elements.borrow_mut(), item, &remove) - } - _ => Err(vm.new_type_error("set.remove is called with no item".to_string())), + fn ge(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { + self._compare_inner( + other, + &|zelf: usize, other: usize| -> bool { zelf < other }, + false, + vm, + ) } -} -fn set_discard(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - trace!("set.discard called with: {:?}", args); - arg_check!( - vm, - args, - required = [(s, Some(vm.ctx.set_type())), (item, None)] - ); - match s.payload::() { - Some(set) => { - fn discard( - vm: &VirtualMachine, - elements: &mut HashMap, - key: u64, - _value: &PyObjectRef, - ) -> PyResult { - elements.remove(&key); - Ok(vm.get_none()) - } - perform_action_with_hash(vm, &mut set.elements.borrow_mut(), item, &discard) - } - None => Err(vm.new_type_error("set.discard is called with no item".to_string())), + fn gt(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { + self._compare_inner( + other, + &|zelf: usize, other: usize| -> bool { zelf <= other }, + false, + vm, + ) } -} -fn set_clear(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - trace!("set.clear called"); - arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]); - match s.payload::() { - Some(set) => { - set.elements.borrow_mut().clear(); - Ok(vm.get_none()) - } - None => Err(vm.new_type_error("".to_string())), + fn le(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { + self._compare_inner( + other, + &|zelf: usize, other: usize| -> bool { zelf < other }, + true, + vm, + ) } -} -/* Create a new object of sub-type of set */ -fn set_new(cls: PyClassRef, iterable: OptionalArg, vm: &VirtualMachine) -> PyResult { - validate_set_or_frozenset(vm, cls.clone())?; + fn lt(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { + self._compare_inner( + other, + &|zelf: usize, other: usize| -> bool { zelf <= other }, + true, + vm, + ) + } - let elements: HashMap = match iterable { - OptionalArg::Missing => HashMap::new(), - OptionalArg::Present(iterable) => { - let mut elements = HashMap::new(); - let iterator = objiter::get_iter(vm, &iterable)?; - while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { - insert_into_set(vm, &mut elements, &v)?; + fn union(&self, other: &PySetInner, _vm: &VirtualMachine) -> PyResult { + let mut elements = self.elements.clone(); + elements.extend(other.elements.clone()); + + Ok(PySetInner { elements }) + } + + fn intersection(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { + self._combine_inner(other, vm, SetCombineOperation::Intersection) + } + + fn difference(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { + self._combine_inner(other, vm, SetCombineOperation::Difference) + } + + fn _combine_inner( + &self, + other: &PySetInner, + vm: &VirtualMachine, + op: SetCombineOperation, + ) -> PyResult { + let mut elements = HashMap::new(); + + for element in self.elements.iter() { + let value = other.contains(element.1.clone(), vm)?; + let should_add = match op { + SetCombineOperation::Intersection => objbool::get_value(&value), + SetCombineOperation::Difference => !objbool::get_value(&value), + }; + if should_add { + elements.insert(element.0.clone(), element.1.clone()); } - elements } - }; - create_set(vm, elements, cls.clone()) -} + Ok(PySetInner { elements }) + } -fn set_len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - trace!("set.len called with: {:?}", args); - arg_check!(vm, args, required = [(s, None)]); - validate_set_or_frozenset(vm, s.class())?; - let elements = get_elements(s); - Ok(vm.context().new_int(elements.len())) -} + fn symmetric_difference( + &self, + other: &PySetInner, + vm: &VirtualMachine, + ) -> PyResult { + let mut elements = HashMap::new(); -fn set_copy(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - trace!("set.copy called with: {:?}", obj); - validate_set_or_frozenset(vm, obj.class())?; - let elements = get_elements(&obj).clone(); - create_set(vm, elements, obj.class()) -} + for element in self.elements.iter() { + let value = other.contains(element.1.clone(), vm)?; + if !objbool::get_value(&value) { + elements.insert(element.0.clone(), element.1.clone()); + } + } + + for element in other.elements.iter() { + let value = self.contains(element.1.clone(), vm)?; + if !objbool::get_value(&value) { + elements.insert(element.0.clone(), element.1.clone()); + } + } + + Ok(PySetInner { elements }) + } -fn set_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(o, Some(vm.ctx.set_type()))]); + fn iter(&self, vm: &VirtualMachine) -> PyListIterator { + let items = self.elements.values().cloned().collect(); + let set_list = vm.ctx.new_list(items); + PyListIterator { + position: Cell::new(0), + list: set_list.downcast().unwrap(), + } + } - let elements = get_elements(o); - let s = if elements.is_empty() { - "set()".to_string() - } else if let Some(_guard) = ReprGuard::enter(o) { + fn repr(&self, vm: &VirtualMachine) -> PyResult { let mut str_parts = vec![]; - for elem in elements.values() { + for elem in self.elements.values() { let part = vm.to_repr(elem)?; str_parts.push(part.value.clone()); } - format!("{{{}}}", str_parts.join(", ")) - } else { - "set(...)".to_string() - }; - Ok(vm.new_str(s)) -} + Ok(format!("{{{}}}", str_parts.join(", "))) + } + + fn add(&mut self, item: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + insert_into_set(vm, &mut self.elements, &item) + } -pub fn set_contains(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(set, None), (needle, None)]); - validate_set_or_frozenset(vm, set.class())?; - for element in get_elements(set).iter() { - match vm._eq(needle.clone(), element.1.clone()) { - Ok(value) => { - if objbool::get_value(&value) { - return Ok(vm.new_bool(true)); + fn remove(&mut self, item: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn remove( + vm: &VirtualMachine, + elements: &mut HashMap, + key: u64, + value: &PyObjectRef, + ) -> PyResult { + match elements.remove(&key) { + None => { + let item_str = format!("{:?}", value); + Err(vm.new_key_error(item_str)) } + Some(_) => Ok(vm.get_none()), } - Err(_) => return Err(vm.new_type_error("".to_string())), } + perform_action_with_hash(vm, &mut self.elements, &item, &remove) } - Ok(vm.new_bool(false)) -} + fn discard(&mut self, item: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn discard( + vm: &VirtualMachine, + elements: &mut HashMap, + key: u64, + _value: &PyObjectRef, + ) -> PyResult { + elements.remove(&key); + Ok(vm.get_none()) + } + perform_action_with_hash(vm, &mut self.elements, &item, &discard) + } -fn set_eq(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_compare_inner( - vm, - args, - &|zelf: usize, other: usize| -> bool { zelf != other }, - false, - ) -} + fn clear(&mut self, vm: &VirtualMachine) -> PyResult { + self.elements.clear(); + Ok(vm.get_none()) + } -fn set_ge(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_compare_inner( - vm, - args, - &|zelf: usize, other: usize| -> bool { zelf < other }, - false, - ) -} + fn pop(&mut self, vm: &VirtualMachine) -> PyResult { + let elements = &mut self.elements; + match elements.clone().keys().next() { + Some(key) => Ok(elements.remove(key).unwrap()), + None => Err(vm.new_key_error("pop from an empty set".to_string())), + } + } -fn set_gt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_compare_inner( - vm, - args, - &|zelf: usize, other: usize| -> bool { zelf <= other }, - false, - ) -} + fn update(&mut self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let iterator = objiter::get_iter(vm, &iterable)?; + while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { + insert_into_set(vm, &mut self.elements, &v)?; + } + Ok(vm.get_none()) + } -fn set_le(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_compare_inner( - vm, - args, - &|zelf: usize, other: usize| -> bool { zelf < other }, - true, - ) -} + fn combine_update_inner( + &mut self, + iterable: &PyObjectRef, + vm: &VirtualMachine, + op: SetCombineOperation, + ) -> PyResult { + let elements = &mut self.elements; + for element in elements.clone().iter() { + let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?; + let should_remove = match op { + SetCombineOperation::Intersection => !objbool::get_value(&value), + SetCombineOperation::Difference => objbool::get_value(&value), + }; + if should_remove { + elements.remove(&element.0.clone()); + } + } + Ok(vm.get_none()) + } + + fn ixor(&mut self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let elements_original = self.elements.clone(); + let iterator = objiter::get_iter(vm, &iterable)?; + while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { + insert_into_set(vm, &mut self.elements, &v)?; + } + for element in elements_original.iter() { + let value = vm.call_method(&iterable, "__contains__", vec![element.1.clone()])?; + if objbool::get_value(&value) { + self.elements.remove(&element.0.clone()); + } + } -fn set_lt(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_compare_inner( - vm, - args, - &|zelf: usize, other: usize| -> bool { zelf <= other }, - true, - ) + Ok(vm.get_none()) + } } -fn set_compare_inner( - vm: &VirtualMachine, - args: PyFuncArgs, - size_func: &Fn(usize, usize) -> bool, - swap: bool, -) -> PyResult { - arg_check!(vm, args, required = [(zelf, None), (other, None)]); +impl PySetRef { + fn len(self, _vm: &VirtualMachine) -> usize { + self.inner.borrow().len() + } + fn copy(self, _vm: &VirtualMachine) -> PySet { + PySet { + inner: RefCell::new(self.inner.borrow().copy()), + } + } + fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().contains(needle, vm) + } - validate_set_or_frozenset(vm, zelf.class())?; - validate_set_or_frozenset(vm, other.class())?; + fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().eq(&get_inner(vm, &other)?, vm) + } - let get_zelf = |swap: bool| -> &PyObjectRef { - if swap { - other - } else { - zelf - } - }; - let get_other = |swap: bool| -> &PyObjectRef { - if swap { - zelf + fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().ge(&get_inner(vm, &other)?, vm) + } + + fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().gt(&get_inner(vm, &other)?, vm) + } + + fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().le(&get_inner(vm, &other)?, vm) + } + + fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().lt(&get_inner(vm, &other)?, vm) + } + + fn union(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PySet { + inner: RefCell::new(self.inner.borrow().union(&get_inner(vm, &other)?, vm)?), + }, + PySet::class(vm), + None, + )) + } + + fn intersection(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PySet { + inner: RefCell::new( + self.inner + .borrow() + .intersection(&get_inner(vm, &other)?, vm)?, + ), + }, + PySet::class(vm), + None, + )) + } + + fn difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PySet { + inner: RefCell::new( + self.inner + .borrow() + .difference(&get_inner(vm, &other)?, vm)?, + ), + }, + PySet::class(vm), + None, + )) + } + + fn symmetric_difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PySet { + inner: RefCell::new( + self.inner + .borrow() + .symmetric_difference(&get_inner(vm, &other)?, vm)?, + ), + }, + PySet::class(vm), + None, + )) + } + + fn iter(self, vm: &VirtualMachine) -> PyListIterator { + self.inner.borrow().iter(vm) + } + + fn repr(self, vm: &VirtualMachine) -> PyResult { + let inner = self.inner.borrow(); + let s = if inner.len() == 0 { + format!("set()") + } else if let Some(_guard) = ReprGuard::enter(self.as_object()) { + inner.repr(vm)? } else { - other - } - }; + format!("set(...)") + }; + Ok(vm.new_str(s)) + } - let zelf_elements = get_elements(get_zelf(swap)); - let other_elements = get_elements(get_other(swap)); - if size_func(zelf_elements.len(), other_elements.len()) { - return Ok(vm.new_bool(false)); + fn add(self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().add(&item, vm) } - for element in other_elements.iter() { - match vm.call_method(get_zelf(swap), "__contains__", vec![element.1.clone()]) { - Ok(value) => { - if !objbool::get_value(&value) { - return Ok(vm.new_bool(false)); - } - } - Err(_) => return Err(vm.new_type_error("".to_string())), - } + + fn remove(self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().remove(&item, vm) } - Ok(vm.new_bool(true)) -} -fn set_union(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - validate_set_or_frozenset(vm, zelf.class())?; - validate_set_or_frozenset(vm, other.class())?; + fn discard(self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().discard(&item, vm) + } - let mut elements = get_elements(&zelf).clone(); - elements.extend(get_elements(&other).clone()); + fn clear(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().clear(vm) + } - create_set(vm, elements, zelf.class()) -} + fn pop(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().pop(vm) + } -fn set_intersection(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - set_combine_inner(zelf, other, vm, SetCombineOperation::Intersection) -} + fn ior(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().update(iterable, vm)?; + Ok(self.as_object().clone()) + } -fn set_difference(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - set_combine_inner(zelf, other, vm, SetCombineOperation::Difference) -} + fn update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().update(iterable, vm)?; + Ok(vm.get_none()) + } -fn set_symmetric_difference( - zelf: PyObjectRef, - other: PyObjectRef, - vm: &VirtualMachine, -) -> PyResult { - validate_set_or_frozenset(vm, zelf.class())?; - validate_set_or_frozenset(vm, other.class())?; - let mut elements = HashMap::new(); - - for element in get_elements(&zelf).iter() { - let value = vm.call_method(&other, "__contains__", vec![element.1.clone()])?; - if !objbool::get_value(&value) { - elements.insert(element.0.clone(), element.1.clone()); - } + fn intersection_update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().combine_update_inner( + &iterable, + vm, + SetCombineOperation::Intersection, + )?; + Ok(vm.get_none()) } - for element in get_elements(&other).iter() { - let value = vm.call_method(&zelf, "__contains__", vec![element.1.clone()])?; - if !objbool::get_value(&value) { - elements.insert(element.0.clone(), element.1.clone()); - } + fn iand(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().combine_update_inner( + &iterable, + vm, + SetCombineOperation::Intersection, + )?; + Ok(self.as_object().clone()) } - create_set(vm, elements, zelf.class()) -} + fn difference_update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().combine_update_inner( + &iterable, + vm, + SetCombineOperation::Difference, + )?; + Ok(vm.get_none()) + } -enum SetCombineOperation { - Intersection, - Difference, + fn isub(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().combine_update_inner( + &iterable, + vm, + SetCombineOperation::Difference, + )?; + Ok(self.as_object().clone()) + } + + fn symmetric_difference_update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().ixor(iterable, vm)?; + Ok(vm.get_none()) + } + + fn ixor(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().ixor(iterable, vm)?; + Ok(self.as_object().clone()) + } } -fn set_combine_inner( - zelf: PyObjectRef, - other: PyObjectRef, - vm: &VirtualMachine, - op: SetCombineOperation, -) -> PyResult { - validate_set_or_frozenset(vm, zelf.class())?; - validate_set_or_frozenset(vm, other.class())?; - let mut elements = HashMap::new(); - - for element in get_elements(&zelf).iter() { - let value = vm.call_method(&other, "__contains__", vec![element.1.clone()])?; - let should_add = match op { - SetCombineOperation::Intersection => objbool::get_value(&value), - SetCombineOperation::Difference => !objbool::get_value(&value), - }; - if should_add { - elements.insert(element.0.clone(), element.1.clone()); +impl PyFrozenSetRef { + fn len(self, _vm: &VirtualMachine) -> usize { + self.inner.len() + } + fn copy(self, _vm: &VirtualMachine) -> PyFrozenSet { + PyFrozenSet { + inner: self.inner.copy(), } } + fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.contains(needle, vm) + } - create_set(vm, elements, zelf.class()) -} + fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.eq(&get_inner(vm, &other)?, vm) + } -fn set_pop(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]); + fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.ge(&get_inner(vm, &other)?, vm) + } - match s.payload::() { - Some(set) => { - let mut elements = set.elements.borrow_mut(); - match elements.clone().keys().next() { - Some(key) => Ok(elements.remove(key).unwrap()), - None => Err(vm.new_key_error("pop from an empty set".to_string())), - } - } - _ => Err(vm.new_type_error("".to_string())), + fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.gt(&get_inner(vm, &other)?, vm) } -} -fn set_update(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_ior(vm, args)?; - Ok(vm.get_none()) -} + fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.le(&get_inner(vm, &other)?, vm) + } -fn set_ior(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)] - ); + fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.lt(&get_inner(vm, &other)?, vm) + } - match zelf.payload::() { - Some(set) => { - let iterator = objiter::get_iter(vm, iterable)?; - while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { - insert_into_set(vm, &mut set.elements.borrow_mut(), &v)?; - } - } - _ => return Err(vm.new_type_error("set.update is called with no other".to_string())), + fn union(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PyFrozenSet { + inner: self.inner.union(&get_inner(vm, &other)?, vm)?, + }, + PyFrozenSet::class(vm), + None, + )) } - Ok(zelf.clone()) -} -fn set_intersection_update(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_combine_update_inner(vm, args, SetCombineOperation::Intersection)?; - Ok(vm.get_none()) -} + fn intersection(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PyFrozenSet { + inner: self.inner.intersection(&get_inner(vm, &other)?, vm)?, + }, + PyFrozenSet::class(vm), + None, + )) + } -fn set_iand(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_combine_update_inner(vm, args, SetCombineOperation::Intersection) + fn difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PyFrozenSet { + inner: self.inner.difference(&get_inner(vm, &other)?, vm)?, + }, + PyFrozenSet::class(vm), + None, + )) + } + + fn symmetric_difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyObject::new( + PyFrozenSet { + inner: self + .inner + .symmetric_difference(&get_inner(vm, &other)?, vm)?, + }, + PyFrozenSet::class(vm), + None, + )) + } + + fn iter(self, vm: &VirtualMachine) -> PyListIterator { + self.inner.iter(vm) + } + + fn repr(self, vm: &VirtualMachine) -> PyResult { + let inner = &self.inner; + let s = if inner.len() == 0 { + format!("frozenset()") + } else if let Some(_guard) = ReprGuard::enter(self.as_object()) { + format!("frozenset({})", inner.repr(vm)?) + } else { + format!("frozenset(...)") + }; + Ok(vm.new_str(s)) + } } -fn set_difference_update(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_combine_update_inner(vm, args, SetCombineOperation::Difference)?; - Ok(vm.get_none()) +fn get_inner(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { + if let Some(set) = obj.payload::() { + Ok(set.inner.borrow().clone()) + } else if let Some(frozenset) = obj.payload::() { + Ok(frozenset.inner.clone()) + } else { + Err(vm.new_type_error(format!( + "{} is not a subtype of set or frozenset", + obj.class() + ))) + } } -fn set_isub(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_combine_update_inner(vm, args, SetCombineOperation::Difference) +fn validate_set_or_frozenset(vm: &VirtualMachine, cls: PyClassRef) -> PyResult<()> { + if !(objtype::issubclass(&cls, &vm.ctx.set_type()) + || objtype::issubclass(&cls, &vm.ctx.frozenset_type())) + { + return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", cls))); + } + Ok(()) } -fn set_combine_update_inner( +fn create_set( vm: &VirtualMachine, - args: PyFuncArgs, - op: SetCombineOperation, + elements: HashMap, + cls: PyClassRef, ) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)] - ); - - match zelf.payload::() { - Some(set) => { - let mut elements = set.elements.borrow_mut(); - for element in elements.clone().iter() { - let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?; - let should_remove = match op { - SetCombineOperation::Intersection => !objbool::get_value(&value), - SetCombineOperation::Difference => objbool::get_value(&value), - }; - if should_remove { - elements.remove(&element.0.clone()); - } - } - } - _ => return Err(vm.new_type_error("".to_string())), + if objtype::issubclass(&cls, &vm.ctx.set_type()) { + Ok(PyObject::new( + PySet { + inner: RefCell::new(PySetInner { elements: elements }), + }, + PySet::class(vm), + None, + )) + } else if objtype::issubclass(&cls, &vm.ctx.frozenset_type()) { + Ok(PyObject::new( + PyFrozenSet { + inner: PySetInner { elements: elements }, + }, + PyFrozenSet::class(vm), + None, + )) + } else { + Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", cls))) } - Ok(zelf.clone()) } -fn set_symmetric_difference_update(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - set_ixor(vm, args)?; - Ok(vm.get_none()) -} - -fn set_ixor(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)] - ); - - match zelf.payload::() { - Some(set) => { - let elements_original = set.elements.borrow().clone(); - let iterator = objiter::get_iter(vm, iterable)?; - while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { - insert_into_set(vm, &mut set.elements.borrow_mut(), &v)?; - } - for element in elements_original.iter() { - let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?; - if objbool::get_value(&value) { - set.elements.borrow_mut().remove(&element.0.clone()); - } - } - } - _ => return Err(vm.new_type_error("".to_string())), - } +fn perform_action_with_hash( + vm: &VirtualMachine, + elements: &mut HashMap, + item: &PyObjectRef, + f: &Fn(&VirtualMachine, &mut HashMap, u64, &PyObjectRef) -> PyResult, +) -> PyResult { + let hash: PyObjectRef = vm.call_method(item, "__hash__", vec![])?; - Ok(zelf.clone()) + let hash_value = objint::get_value(&hash); + let mut hasher = DefaultHasher::new(); + hash_value.hash(&mut hasher); + let key = hasher.finish(); + f(vm, elements, key, item) } -fn set_iter(zelf: PySetRef, vm: &VirtualMachine) -> PyListIterator { - // TODO: separate type - let items = zelf.elements.borrow().values().cloned().collect(); - let set_list = vm.ctx.new_list(items); - PyListIterator { - position: Cell::new(0), - list: set_list.downcast().unwrap(), +fn insert_into_set( + vm: &VirtualMachine, + elements: &mut HashMap, + item: &PyObjectRef, +) -> PyResult { + fn insert( + vm: &VirtualMachine, + elements: &mut HashMap, + key: u64, + value: &PyObjectRef, + ) -> PyResult { + elements.insert(key, value.clone()); + Ok(vm.get_none()) } + perform_action_with_hash(vm, elements, item, &insert) } -fn frozenset_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(o, Some(vm.ctx.frozenset_type()))]); +/* Create a new object of sub-type of set */ +fn set_new(cls: PyClassRef, iterable: OptionalArg, vm: &VirtualMachine) -> PyResult { + validate_set_or_frozenset(vm, cls.clone())?; - let elements = get_elements(o); - let s = if elements.is_empty() { - "frozenset()".to_string() - } else { - let mut str_parts = vec![]; - for elem in elements.values() { - let part = vm.to_repr(elem)?; - str_parts.push(part.value.clone()); + let elements: HashMap = match iterable { + OptionalArg::Missing => HashMap::new(), + OptionalArg::Present(iterable) => { + let mut elements = HashMap::new(); + let iterator = objiter::get_iter(vm, &iterable)?; + while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { + insert_into_set(vm, &mut elements, &v)?; + } + elements } - - format!("frozenset({{{}}})", str_parts.join(", ")) }; - Ok(vm.new_str(s)) + + create_set(vm, elements, cls.clone()) +} + +enum SetCombineOperation { + Intersection, + Difference, } fn set_hash(_zelf: PySetRef, vm: &VirtualMachine) -> PyResult { @@ -600,42 +719,42 @@ pub fn init(context: &PyContext) { Build an unordered collection of unique elements."; extend_class!(context, set_type, { - "__contains__" => context.new_rustfunc(set_contains), - "__len__" => context.new_rustfunc(set_len), + "__contains__" => context.new_rustfunc(PySetRef::contains), + "__len__" => context.new_rustfunc(PySetRef::len), "__new__" => context.new_rustfunc(set_new), - "__repr__" => context.new_rustfunc(set_repr), "__hash__" => context.new_rustfunc(set_hash), - "__eq__" => context.new_rustfunc(set_eq), - "__ge__" => context.new_rustfunc(set_ge), - "__gt__" => context.new_rustfunc(set_gt), - "__le__" => context.new_rustfunc(set_le), - "__lt__" => context.new_rustfunc(set_lt), - "issubset" => context.new_rustfunc(set_le), - "issuperset" => context.new_rustfunc(set_ge), - "union" => context.new_rustfunc(set_union), - "__or__" => context.new_rustfunc(set_union), - "intersection" => context.new_rustfunc(set_intersection), - "__and__" => context.new_rustfunc(set_intersection), - "difference" => context.new_rustfunc(set_difference), - "__sub__" => context.new_rustfunc(set_difference), - "symmetric_difference" => context.new_rustfunc(set_symmetric_difference), - "__xor__" => context.new_rustfunc(set_symmetric_difference), + "__repr__" => context.new_rustfunc(PySetRef::repr), + "__eq__" => context.new_rustfunc(PySetRef::eq), + "__ge__" => context.new_rustfunc(PySetRef::ge), + "__gt__" => context.new_rustfunc(PySetRef::gt), + "__le__" => context.new_rustfunc(PySetRef::le), + "__lt__" => context.new_rustfunc(PySetRef::lt), + "issubset" => context.new_rustfunc(PySetRef::le), + "issuperset" => context.new_rustfunc(PySetRef::ge), + "union" => context.new_rustfunc(PySetRef::union), + "__or__" => context.new_rustfunc(PySetRef::union), + "intersection" => context.new_rustfunc(PySetRef::intersection), + "__and__" => context.new_rustfunc(PySetRef::intersection), + "difference" => context.new_rustfunc(PySetRef::difference), + "__sub__" => context.new_rustfunc(PySetRef::difference), + "symmetric_difference" => context.new_rustfunc(PySetRef::symmetric_difference), + "__xor__" => context.new_rustfunc(PySetRef::symmetric_difference), "__doc__" => context.new_str(set_doc.to_string()), - "add" => context.new_rustfunc(set_add), - "remove" => context.new_rustfunc(set_remove), - "discard" => context.new_rustfunc(set_discard), - "clear" => context.new_rustfunc(set_clear), - "copy" => context.new_rustfunc(set_copy), - "pop" => context.new_rustfunc(set_pop), - "update" => context.new_rustfunc(set_update), - "__ior__" => context.new_rustfunc(set_ior), - "intersection_update" => context.new_rustfunc(set_intersection_update), - "__iand__" => context.new_rustfunc(set_iand), - "difference_update" => context.new_rustfunc(set_difference_update), - "__isub__" => context.new_rustfunc(set_isub), - "symmetric_difference_update" => context.new_rustfunc(set_symmetric_difference_update), - "__ixor__" => context.new_rustfunc(set_ixor), - "__iter__" => context.new_rustfunc(set_iter) + "add" => context.new_rustfunc(PySetRef::add), + "remove" => context.new_rustfunc(PySetRef::remove), + "discard" => context.new_rustfunc(PySetRef::discard), + "clear" => context.new_rustfunc(PySetRef::clear), + "copy" => context.new_rustfunc(PySetRef::copy), + "pop" => context.new_rustfunc(PySetRef::pop), + "update" => context.new_rustfunc(PySetRef::update), + "__ior__" => context.new_rustfunc(PySetRef::ior), + "intersection_update" => context.new_rustfunc(PySetRef::intersection_update), + "__iand__" => context.new_rustfunc(PySetRef::iand), + "difference_update" => context.new_rustfunc(PySetRef::difference_update), + "__isub__" => context.new_rustfunc(PySetRef::isub), + "symmetric_difference_update" => context.new_rustfunc(PySetRef::symmetric_difference_update), + "__ixor__" => context.new_rustfunc(PySetRef::ixor), + "__iter__" => context.new_rustfunc(PySetRef::iter) }); let frozenset_type = &context.frozenset_type; @@ -646,25 +765,26 @@ pub fn init(context: &PyContext) { extend_class!(context, frozenset_type, { "__new__" => context.new_rustfunc(set_new), - "__eq__" => context.new_rustfunc(set_eq), - "__ge__" => context.new_rustfunc(set_ge), - "__gt__" => context.new_rustfunc(set_gt), - "__le__" => context.new_rustfunc(set_le), - "__lt__" => context.new_rustfunc(set_lt), - "issubset" => context.new_rustfunc(set_le), - "issuperset" => context.new_rustfunc(set_ge), - "union" => context.new_rustfunc(set_union), - "__or__" => context.new_rustfunc(set_union), - "intersection" => context.new_rustfunc(set_intersection), - "__and__" => context.new_rustfunc(set_intersection), - "difference" => context.new_rustfunc(set_difference), - "__sub__" => context.new_rustfunc(set_difference), - "symmetric_difference" => context.new_rustfunc(set_symmetric_difference), - "__xor__" => context.new_rustfunc(set_symmetric_difference), - "__contains__" => context.new_rustfunc(set_contains), - "__len__" => context.new_rustfunc(set_len), + "__eq__" => context.new_rustfunc(PyFrozenSetRef::eq), + "__ge__" => context.new_rustfunc(PyFrozenSetRef::ge), + "__gt__" => context.new_rustfunc(PyFrozenSetRef::gt), + "__le__" => context.new_rustfunc(PyFrozenSetRef::le), + "__lt__" => context.new_rustfunc(PyFrozenSetRef::lt), + "issubset" => context.new_rustfunc(PyFrozenSetRef::le), + "issuperset" => context.new_rustfunc(PyFrozenSetRef::ge), + "union" => context.new_rustfunc(PyFrozenSetRef::union), + "__or__" => context.new_rustfunc(PyFrozenSetRef::union), + "intersection" => context.new_rustfunc(PyFrozenSetRef::intersection), + "__and__" => context.new_rustfunc(PyFrozenSetRef::intersection), + "difference" => context.new_rustfunc(PyFrozenSetRef::difference), + "__sub__" => context.new_rustfunc(PyFrozenSetRef::difference), + "symmetric_difference" => context.new_rustfunc(PyFrozenSetRef::symmetric_difference), + "__xor__" => context.new_rustfunc(PyFrozenSetRef::symmetric_difference), + "__contains__" => context.new_rustfunc(PyFrozenSetRef::contains), + "__len__" => context.new_rustfunc(PyFrozenSetRef::len), "__doc__" => context.new_str(frozenset_doc.to_string()), - "__repr__" => context.new_rustfunc(frozenset_repr), - "copy" => context.new_rustfunc(set_copy) + "__repr__" => context.new_rustfunc(PyFrozenSetRef::repr), + "copy" => context.new_rustfunc(PyFrozenSetRef::copy), + "__iter__" => context.new_rustfunc(PyFrozenSetRef::iter) }); } From a685216f29ecc75e2bc3f9637a42e5fe878c5b37 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 4 Apr 2019 20:47:29 +0300 Subject: [PATCH 02/11] Use new arg style for set and frozenset new --- vm/src/obj/objset.rs | 103 +++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 58 deletions(-) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index eaa0cd09ca..fd31044aa4 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -15,7 +15,6 @@ use super::objbool; use super::objint; use super::objiter; use super::objlist::PyListIterator; -use super::objtype; use super::objtype::PyClassRef; #[derive(Default)] @@ -62,6 +61,22 @@ struct PySetInner { } impl PySetInner { + fn new(iterable: OptionalArg, vm: &VirtualMachine) -> PyResult { + let elements: HashMap = match iterable { + OptionalArg::Missing => HashMap::new(), + OptionalArg::Present(iterable) => { + let mut elements = HashMap::new(); + let iterator = objiter::get_iter(vm, &iterable)?; + while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { + insert_into_set(vm, &mut elements, &v)?; + } + elements + } + }; + + Ok(PySetInner { elements }) + } + fn len(&self) -> usize { self.elements.len() } @@ -334,14 +349,27 @@ impl PySetInner { } impl PySetRef { + fn new( + cls: PyClassRef, + iterable: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + PySet { + inner: RefCell::new(PySetInner::new(iterable, vm)?), + } + .into_ref_with_type(vm, cls) + } + fn len(self, _vm: &VirtualMachine) -> usize { self.inner.borrow().len() } + fn copy(self, _vm: &VirtualMachine) -> PySet { PySet { inner: RefCell::new(self.inner.borrow().copy()), } } + fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { self.inner.borrow().contains(needle, vm) } @@ -512,14 +540,27 @@ impl PySetRef { } impl PyFrozenSetRef { + fn new( + cls: PyClassRef, + iterable: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + PyFrozenSet { + inner: PySetInner::new(iterable, vm)?, + } + .into_ref_with_type(vm, cls) + } + fn len(self, _vm: &VirtualMachine) -> usize { self.inner.len() } + fn copy(self, _vm: &VirtualMachine) -> PyFrozenSet { PyFrozenSet { inner: self.inner.copy(), } } + fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { self.inner.contains(needle, vm) } @@ -616,41 +657,6 @@ fn get_inner(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { } } -fn validate_set_or_frozenset(vm: &VirtualMachine, cls: PyClassRef) -> PyResult<()> { - if !(objtype::issubclass(&cls, &vm.ctx.set_type()) - || objtype::issubclass(&cls, &vm.ctx.frozenset_type())) - { - return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", cls))); - } - Ok(()) -} - -fn create_set( - vm: &VirtualMachine, - elements: HashMap, - cls: PyClassRef, -) -> PyResult { - if objtype::issubclass(&cls, &vm.ctx.set_type()) { - Ok(PyObject::new( - PySet { - inner: RefCell::new(PySetInner { elements: elements }), - }, - PySet::class(vm), - None, - )) - } else if objtype::issubclass(&cls, &vm.ctx.frozenset_type()) { - Ok(PyObject::new( - PyFrozenSet { - inner: PySetInner { elements: elements }, - }, - PyFrozenSet::class(vm), - None, - )) - } else { - Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", cls))) - } -} - fn perform_action_with_hash( vm: &VirtualMachine, elements: &mut HashMap, @@ -683,25 +689,6 @@ fn insert_into_set( perform_action_with_hash(vm, elements, item, &insert) } -/* Create a new object of sub-type of set */ -fn set_new(cls: PyClassRef, iterable: OptionalArg, vm: &VirtualMachine) -> PyResult { - validate_set_or_frozenset(vm, cls.clone())?; - - let elements: HashMap = match iterable { - OptionalArg::Missing => HashMap::new(), - OptionalArg::Present(iterable) => { - let mut elements = HashMap::new(); - let iterator = objiter::get_iter(vm, &iterable)?; - while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { - insert_into_set(vm, &mut elements, &v)?; - } - elements - } - }; - - create_set(vm, elements, cls.clone()) -} - enum SetCombineOperation { Intersection, Difference, @@ -719,10 +706,10 @@ pub fn init(context: &PyContext) { Build an unordered collection of unique elements."; extend_class!(context, set_type, { + "__hash__" => context.new_rustfunc(set_hash), "__contains__" => context.new_rustfunc(PySetRef::contains), "__len__" => context.new_rustfunc(PySetRef::len), - "__new__" => context.new_rustfunc(set_new), - "__hash__" => context.new_rustfunc(set_hash), + "__new__" => context.new_rustfunc(PySetRef::new), "__repr__" => context.new_rustfunc(PySetRef::repr), "__eq__" => context.new_rustfunc(PySetRef::eq), "__ge__" => context.new_rustfunc(PySetRef::ge), @@ -764,7 +751,7 @@ pub fn init(context: &PyContext) { Build an immutable unordered collection of unique elements."; extend_class!(context, frozenset_type, { - "__new__" => context.new_rustfunc(set_new), + "__new__" => context.new_rustfunc(PyFrozenSetRef::new), "__eq__" => context.new_rustfunc(PyFrozenSetRef::eq), "__ge__" => context.new_rustfunc(PyFrozenSetRef::ge), "__gt__" => context.new_rustfunc(PyFrozenSetRef::gt), From 2a3a88c2f2cf0f13f0ed89843f19680f43b8d403 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 4 Apr 2019 21:23:00 +0300 Subject: [PATCH 03/11] Use match_class macro to avoid unwanted clones --- vm/src/obj/objset.rs | 135 ++++++++++++++++++++++++++++--------------- 1 file changed, 90 insertions(+), 45 deletions(-) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index fd31044aa4..6c67312fb1 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -375,29 +375,53 @@ impl PySetRef { } fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().eq(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.borrow().eq(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.borrow().eq(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().ge(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.borrow().ge(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.borrow().ge(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().gt(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.borrow().gt(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.borrow().gt(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().le(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.borrow().le(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.borrow().le(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().lt(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.borrow().lt(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.borrow().lt(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn union(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new(self.inner.borrow().union(&get_inner(vm, &other)?, vm)?), + inner: RefCell::new(match_class!(other, + set @ PySet => self.inner.borrow().union(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.borrow().union(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + )), }, PySet::class(vm), None, @@ -407,11 +431,11 @@ impl PySetRef { fn intersection(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new( - self.inner - .borrow() - .intersection(&get_inner(vm, &other)?, vm)?, - ), + inner: RefCell::new(match_class!(other, + set @ PySet => self.inner.borrow().intersection(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.borrow().intersection(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + )), }, PySet::class(vm), None, @@ -421,11 +445,11 @@ impl PySetRef { fn difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new( - self.inner - .borrow() - .difference(&get_inner(vm, &other)?, vm)?, - ), + inner: RefCell::new(match_class!(other, + set @ PySet => self.inner.borrow().difference(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.borrow().difference(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + )), }, PySet::class(vm), None, @@ -435,11 +459,11 @@ impl PySetRef { fn symmetric_difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new( - self.inner - .borrow() - .symmetric_difference(&get_inner(vm, &other)?, vm)?, - ), + inner: RefCell::new(match_class!(other, + set @ PySet => self.inner.borrow().symmetric_difference(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.borrow().symmetric_difference(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + )), }, PySet::class(vm), None, @@ -566,29 +590,53 @@ impl PyFrozenSetRef { } fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.eq(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.eq(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.eq(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.ge(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.ge(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.ge(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.gt(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.gt(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.gt(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.le(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.le(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.le(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.lt(&get_inner(vm, &other)?, vm) + match_class!(other, + set @ PySet => self.inner.lt(&set.inner.borrow(), vm), + frozen @ PyFrozenSet => self.inner.lt(&frozen.inner, vm), + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ) } fn union(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: self.inner.union(&get_inner(vm, &other)?, vm)?, + inner: match_class!(other, + set @ PySet => self.inner.union(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.union(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ), }, PyFrozenSet::class(vm), None, @@ -598,7 +646,11 @@ impl PyFrozenSetRef { fn intersection(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: self.inner.intersection(&get_inner(vm, &other)?, vm)?, + inner: match_class!(other, + set @ PySet => self.inner.intersection(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.intersection(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ), }, PyFrozenSet::class(vm), None, @@ -608,7 +660,11 @@ impl PyFrozenSetRef { fn difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: self.inner.difference(&get_inner(vm, &other)?, vm)?, + inner: match_class!(other, + set @ PySet => self.inner.difference(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.difference(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ), }, PyFrozenSet::class(vm), None, @@ -618,9 +674,11 @@ impl PyFrozenSetRef { fn symmetric_difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: self - .inner - .symmetric_difference(&get_inner(vm, &other)?, vm)?, + inner: match_class!(other, + set @ PySet => self.inner.symmetric_difference(&set.inner.borrow(), vm)?, + frozen @ PyFrozenSet => self.inner.symmetric_difference(&frozen.inner, vm)?, + other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, + ), }, PyFrozenSet::class(vm), None, @@ -644,19 +702,6 @@ impl PyFrozenSetRef { } } -fn get_inner(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { - if let Some(set) = obj.payload::() { - Ok(set.inner.borrow().clone()) - } else if let Some(frozenset) = obj.payload::() { - Ok(frozenset.inner.clone()) - } else { - Err(vm.new_type_error(format!( - "{} is not a subtype of set or frozenset", - obj.class() - ))) - } -} - fn perform_action_with_hash( vm: &VirtualMachine, elements: &mut HashMap, From a0aa88d2fb3aba561e0a085a14f618f7008f1f92 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 20:02:27 +0300 Subject: [PATCH 04/11] clear return implicit None --- vm/src/obj/objset.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 6c67312fb1..c3c580bc4b 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -290,9 +290,8 @@ impl PySetInner { perform_action_with_hash(vm, &mut self.elements, &item, &discard) } - fn clear(&mut self, vm: &VirtualMachine) -> PyResult { + fn clear(&mut self) -> () { self.elements.clear(); - Ok(vm.get_none()) } fn pop(&mut self, vm: &VirtualMachine) -> PyResult { @@ -498,8 +497,8 @@ impl PySetRef { self.inner.borrow_mut().discard(&item, vm) } - fn clear(self, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().clear(vm) + fn clear(self, _vm: &VirtualMachine) -> () { + self.inner.borrow_mut().clear() } fn pop(self, vm: &VirtualMachine) -> PyResult { From 3bc1e3598c9999a4f1aca5ea72a3d4f749702317 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 21:26:20 +0300 Subject: [PATCH 05/11] Use PyIterable --- tests/snippets/set.py | 2 ++ vm/src/obj/objset.rs | 48 +++++++++++++++++++------------------------ 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index 0117163350..11d62944a3 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -49,6 +49,7 @@ def __hash__(self): assert set([1,2,3]).union(set([4,5])) == set([1,2,3,4,5]) assert set([1,2,3]).union(set([1,2,3,4,5])) == set([1,2,3,4,5]) +assert set([1,2,3]).union([1,2,3,4,5]) == set([1,2,3,4,5]) assert set([1,2,3]) | set([4,5]) == set([1,2,3,4,5]) assert set([1,2,3]) | set([1,2,3,4,5]) == set([1,2,3,4,5]) @@ -181,6 +182,7 @@ def __hash__(self): assert frozenset([1,2,3]).union(frozenset([4,5])) == frozenset([1,2,3,4,5]) assert frozenset([1,2,3]).union(frozenset([1,2,3,4,5])) == frozenset([1,2,3,4,5]) +assert frozenset([1,2,3]).union([1,2,3,4,5]) == frozenset([1,2,3,4,5]) assert frozenset([1,2,3]) | frozenset([4,5]) == frozenset([1,2,3,4,5]) assert frozenset([1,2,3]) | frozenset([1,2,3,4,5]) == frozenset([1,2,3,4,5]) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index c3c580bc4b..46ba1db664 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -8,7 +8,9 @@ use std::fmt; use std::hash::{Hash, Hasher}; use crate::function::OptionalArg; -use crate::pyobject::{PyContext, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol}; +use crate::pyobject::{ + PyContext, PyIterable, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, +}; use crate::vm::{ReprGuard, VirtualMachine}; use super::objbool; @@ -61,14 +63,13 @@ struct PySetInner { } impl PySetInner { - fn new(iterable: OptionalArg, vm: &VirtualMachine) -> PyResult { + fn new(iterable: OptionalArg, vm: &VirtualMachine) -> PyResult { let elements: HashMap = match iterable { OptionalArg::Missing => HashMap::new(), OptionalArg::Present(iterable) => { let mut elements = HashMap::new(); - let iterator = objiter::get_iter(vm, &iterable)?; - while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { - insert_into_set(vm, &mut elements, &v)?; + for item in iterable.iter(vm)? { + insert_into_set(vm, &mut elements, &item?)?; } elements } @@ -175,9 +176,11 @@ impl PySetInner { ) } - fn union(&self, other: &PySetInner, _vm: &VirtualMachine) -> PyResult { + fn union(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult { let mut elements = self.elements.clone(); - elements.extend(other.elements.clone()); + for item in other.iter(vm)? { + insert_into_set(vm, &mut elements, &item?)?; + } Ok(PySetInner { elements }) } @@ -302,10 +305,9 @@ impl PySetInner { } } - fn update(&mut self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let iterator = objiter::get_iter(vm, &iterable)?; - while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { - insert_into_set(vm, &mut self.elements, &v)?; + fn update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + for item in iterable.iter(vm)? { + insert_into_set(vm, &mut self.elements, &item?)?; } Ok(vm.get_none()) } @@ -350,7 +352,7 @@ impl PySetInner { impl PySetRef { fn new( cls: PyClassRef, - iterable: OptionalArg, + iterable: OptionalArg, vm: &VirtualMachine, ) -> PyResult { PySet { @@ -413,14 +415,10 @@ impl PySetRef { ) } - fn union(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn union(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new(match_class!(other, - set @ PySet => self.inner.borrow().union(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.borrow().union(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - )), + inner: RefCell::new(self.inner.borrow().union(other, vm)?), }, PySet::class(vm), None, @@ -505,12 +503,12 @@ impl PySetRef { self.inner.borrow_mut().pop(vm) } - fn ior(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn ior(self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { self.inner.borrow_mut().update(iterable, vm)?; Ok(self.as_object().clone()) } - fn update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn update(self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { self.inner.borrow_mut().update(iterable, vm)?; Ok(vm.get_none()) } @@ -565,7 +563,7 @@ impl PySetRef { impl PyFrozenSetRef { fn new( cls: PyClassRef, - iterable: OptionalArg, + iterable: OptionalArg, vm: &VirtualMachine, ) -> PyResult { PyFrozenSet { @@ -628,14 +626,10 @@ impl PyFrozenSetRef { ) } - fn union(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn union(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: match_class!(other, - set @ PySet => self.inner.union(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.union(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - ), + inner: self.inner.union(other, vm)?, }, PyFrozenSet::class(vm), None, From fa369ff77964b53d34a40a3a481bc728c2e0be47 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 21:58:37 +0300 Subject: [PATCH 06/11] Limit non-operator versions to set and frozenset --- tests/snippets/set.py | 2 ++ vm/src/obj/objset.rs | 37 ++++++++++++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index 11d62944a3..174426aa2b 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -53,6 +53,7 @@ def __hash__(self): assert set([1,2,3]) | set([4,5]) == set([1,2,3,4,5]) assert set([1,2,3]) | set([1,2,3,4,5]) == set([1,2,3,4,5]) +assert_raises(TypeError, lambda: set([1,2,3]) | [1,2,3,4,5]) assert set([1,2,3]).intersection(set([1,2])) == set([1,2]) assert set([1,2,3]).intersection(set([5,6])) == set([]) @@ -186,6 +187,7 @@ def __hash__(self): assert frozenset([1,2,3]) | frozenset([4,5]) == frozenset([1,2,3,4,5]) assert frozenset([1,2,3]) | frozenset([1,2,3,4,5]) == frozenset([1,2,3,4,5]) +assert_raises(TypeError, lambda: frozenset([1,2,3]) | [1,2,3,4,5]) assert frozenset([1,2,3]).intersection(frozenset([1,2])) == frozenset([1,2]) assert frozenset([1,2,3]).intersection(frozenset([5,6])) == frozenset([]) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 46ba1db664..2aa00431bf 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -9,7 +9,8 @@ use std::hash::{Hash, Hasher}; use crate::function::OptionalArg; use crate::pyobject::{ - PyContext, PyIterable, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + PyContext, PyIterable, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::vm::{ReprGuard, VirtualMachine}; @@ -17,6 +18,7 @@ use super::objbool; use super::objint; use super::objiter; use super::objlist::PyListIterator; +use super::objtype; use super::objtype::PyClassRef; #[derive(Default)] @@ -467,6 +469,10 @@ impl PySetRef { )) } + fn or(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.union(other.iterable, vm) + } + fn iter(self, vm: &VirtualMachine) -> PyListIterator { self.inner.borrow().iter(vm) } @@ -636,6 +642,10 @@ impl PyFrozenSetRef { )) } + fn or(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.union(other.iterable, vm) + } + fn intersection(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { @@ -732,6 +742,27 @@ enum SetCombineOperation { Difference, } +struct SetIterable { + iterable: PyIterable, +} + +impl TryFromObject for SetIterable { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + if objtype::issubclass(&obj.class(), &vm.ctx.set_type()) + || objtype::issubclass(&obj.class(), &vm.ctx.frozenset_type()) + { + Ok(SetIterable { + iterable: PyIterable::try_from_object(vm, obj)?, + }) + } else { + Err(vm.new_type_error(format!( + "{} is not a subtype of set or frozenset", + obj.class() + ))) + } + } +} + fn set_hash(_zelf: PySetRef, vm: &VirtualMachine) -> PyResult { Err(vm.new_type_error("unhashable type".to_string())) } @@ -757,7 +788,7 @@ pub fn init(context: &PyContext) { "issubset" => context.new_rustfunc(PySetRef::le), "issuperset" => context.new_rustfunc(PySetRef::ge), "union" => context.new_rustfunc(PySetRef::union), - "__or__" => context.new_rustfunc(PySetRef::union), + "__or__" => context.new_rustfunc(PySetRef::or), "intersection" => context.new_rustfunc(PySetRef::intersection), "__and__" => context.new_rustfunc(PySetRef::intersection), "difference" => context.new_rustfunc(PySetRef::difference), @@ -798,7 +829,7 @@ pub fn init(context: &PyContext) { "issubset" => context.new_rustfunc(PyFrozenSetRef::le), "issuperset" => context.new_rustfunc(PyFrozenSetRef::ge), "union" => context.new_rustfunc(PyFrozenSetRef::union), - "__or__" => context.new_rustfunc(PyFrozenSetRef::union), + "__or__" => context.new_rustfunc(PyFrozenSetRef::or), "intersection" => context.new_rustfunc(PyFrozenSetRef::intersection), "__and__" => context.new_rustfunc(PyFrozenSetRef::intersection), "difference" => context.new_rustfunc(PyFrozenSetRef::difference), From e011e3f32721fa1e7b3cd749152956cde336a5ee Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 22:36:56 +0300 Subject: [PATCH 07/11] intersection and difference support iterable --- tests/snippets/set.py | 8 +++ vm/src/obj/objset.rs | 142 +++++++++++++++++++++--------------------- 2 files changed, 80 insertions(+), 70 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index 174426aa2b..a7977f35de 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -57,15 +57,19 @@ def __hash__(self): assert set([1,2,3]).intersection(set([1,2])) == set([1,2]) assert set([1,2,3]).intersection(set([5,6])) == set([]) +assert set([1,2,3]).intersection([1,2]) == set([1,2]) assert set([1,2,3]) & set([4,5]) == set([]) assert set([1,2,3]) & set([1,2,3,4,5]) == set([1,2,3]) +assert_raises(TypeError, lambda: set([1,2,3]) & [1,2,3,4,5]) assert set([1,2,3]).difference(set([1,2])) == set([3]) assert set([1,2,3]).difference(set([5,6])) == set([1,2,3]) +assert set([1,2,3]).difference([1,2]) == set([3]) assert set([1,2,3]) - set([4,5]) == set([1,2,3]) assert set([1,2,3]) - set([1,2,3,4,5]) == set([]) +assert_raises(TypeError, lambda: set([1,2,3]) - [1,2,3,4,5]) assert set([1,2,3]).symmetric_difference(set([1,2])) == set([3]) assert set([1,2,3]).symmetric_difference(set([5,6])) == set([1,2,3,5,6]) @@ -191,15 +195,19 @@ def __hash__(self): assert frozenset([1,2,3]).intersection(frozenset([1,2])) == frozenset([1,2]) assert frozenset([1,2,3]).intersection(frozenset([5,6])) == frozenset([]) +assert frozenset([1,2,3]).intersection([1,2]) == frozenset([1,2]) assert frozenset([1,2,3]) & frozenset([4,5]) == frozenset([]) assert frozenset([1,2,3]) & frozenset([1,2,3,4,5]) == frozenset([1,2,3]) +assert_raises(TypeError, lambda: frozenset([1,2,3]) & [1,2,3,4,5]) assert frozenset([1,2,3]).difference(frozenset([1,2])) == frozenset([3]) assert frozenset([1,2,3]).difference(frozenset([5,6])) == frozenset([1,2,3]) +assert frozenset([1,2,3]).difference([1,2]) == frozenset([3]) assert frozenset([1,2,3]) - frozenset([4,5]) == frozenset([1,2,3]) assert frozenset([1,2,3]) - frozenset([1,2,3,4,5]) == frozenset([]) +assert_raises(TypeError, lambda: frozenset([1,2,3]) - [1,2,3,4,5]) assert frozenset([1,2,3]).symmetric_difference(frozenset([1,2])) == frozenset([3]) assert frozenset([1,2,3]).symmetric_difference(frozenset([5,6])) == frozenset([1,2,3,5,6]) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 2aa00431bf..508400694f 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -187,33 +187,27 @@ impl PySetInner { Ok(PySetInner { elements }) } - fn intersection(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { - self._combine_inner(other, vm, SetCombineOperation::Intersection) - } - - fn difference(&self, other: &PySetInner, vm: &VirtualMachine) -> PyResult { - self._combine_inner(other, vm, SetCombineOperation::Difference) - } - - fn _combine_inner( - &self, - other: &PySetInner, - vm: &VirtualMachine, - op: SetCombineOperation, - ) -> PyResult { + fn intersection(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult { let mut elements = HashMap::new(); - - for element in self.elements.iter() { - let value = other.contains(element.1.clone(), vm)?; - let should_add = match op { - SetCombineOperation::Intersection => objbool::get_value(&value), - SetCombineOperation::Difference => !objbool::get_value(&value), - }; - if should_add { - elements.insert(element.0.clone(), element.1.clone()); + for item in other.iter(vm)? { + if let Ok(obj) = item { + if objbool::get_value(&self.contains(obj.clone(), vm)?) { + insert_into_set(vm, &mut elements, &obj)?; + } } } + Ok(PySetInner { elements }) + } + fn difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult { + let mut elements = self.elements.clone(); + for item in other.iter(vm)? { + if let Ok(obj) = item { + if objbool::get_value(&self.contains(obj.clone(), vm)?) { + remove_from_set(vm, &mut elements, &obj)?; + } + } + } Ok(PySetInner { elements }) } @@ -265,21 +259,7 @@ impl PySetInner { } fn remove(&mut self, item: &PyObjectRef, vm: &VirtualMachine) -> PyResult { - fn remove( - vm: &VirtualMachine, - elements: &mut HashMap, - key: u64, - value: &PyObjectRef, - ) -> PyResult { - match elements.remove(&key) { - None => { - let item_str = format!("{:?}", value); - Err(vm.new_key_error(item_str)) - } - Some(_) => Ok(vm.get_none()), - } - } - perform_action_with_hash(vm, &mut self.elements, &item, &remove) + remove_from_set(vm, &mut self.elements, &item) } fn discard(&mut self, item: &PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -427,28 +407,20 @@ impl PySetRef { )) } - fn intersection(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn intersection(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new(match_class!(other, - set @ PySet => self.inner.borrow().intersection(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.borrow().intersection(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - )), + inner: RefCell::new(self.inner.borrow().intersection(other, vm)?), }, PySet::class(vm), None, )) } - fn difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn difference(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new(match_class!(other, - set @ PySet => self.inner.borrow().difference(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.borrow().difference(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - )), + inner: RefCell::new(self.inner.borrow().difference(other, vm)?), }, PySet::class(vm), None, @@ -473,6 +445,14 @@ impl PySetRef { self.union(other.iterable, vm) } + fn and(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.intersection(other.iterable, vm) + } + + fn sub(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.difference(other.iterable, vm) + } + fn iter(self, vm: &VirtualMachine) -> PyListIterator { self.inner.borrow().iter(vm) } @@ -642,32 +622,20 @@ impl PyFrozenSetRef { )) } - fn or(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { - self.union(other.iterable, vm) - } - - fn intersection(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn intersection(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: match_class!(other, - set @ PySet => self.inner.intersection(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.intersection(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - ), + inner: self.inner.intersection(other, vm)?, }, PyFrozenSet::class(vm), None, )) } - fn difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn difference(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: match_class!(other, - set @ PySet => self.inner.difference(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.difference(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - ), + inner: self.inner.difference(other, vm)?, }, PyFrozenSet::class(vm), None, @@ -688,6 +656,18 @@ impl PyFrozenSetRef { )) } + fn or(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.union(other.iterable, vm) + } + + fn and(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.intersection(other.iterable, vm) + } + + fn sub(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.difference(other.iterable, vm) + } + fn iter(self, vm: &VirtualMachine) -> PyListIterator { self.inner.iter(vm) } @@ -737,6 +717,28 @@ fn insert_into_set( perform_action_with_hash(vm, elements, item, &insert) } +fn remove_from_set( + vm: &VirtualMachine, + elements: &mut HashMap, + item: &PyObjectRef, +) -> PyResult { + fn remove( + vm: &VirtualMachine, + elements: &mut HashMap, + key: u64, + value: &PyObjectRef, + ) -> PyResult { + match elements.remove(&key) { + None => { + let item_str = format!("{:?}", value); + Err(vm.new_key_error(item_str)) + } + Some(_) => Ok(vm.get_none()), + } + } + perform_action_with_hash(vm, elements, item, &remove) +} + enum SetCombineOperation { Intersection, Difference, @@ -790,9 +792,9 @@ pub fn init(context: &PyContext) { "union" => context.new_rustfunc(PySetRef::union), "__or__" => context.new_rustfunc(PySetRef::or), "intersection" => context.new_rustfunc(PySetRef::intersection), - "__and__" => context.new_rustfunc(PySetRef::intersection), + "__and__" => context.new_rustfunc(PySetRef::and), "difference" => context.new_rustfunc(PySetRef::difference), - "__sub__" => context.new_rustfunc(PySetRef::difference), + "__sub__" => context.new_rustfunc(PySetRef::sub), "symmetric_difference" => context.new_rustfunc(PySetRef::symmetric_difference), "__xor__" => context.new_rustfunc(PySetRef::symmetric_difference), "__doc__" => context.new_str(set_doc.to_string()), @@ -831,9 +833,9 @@ pub fn init(context: &PyContext) { "union" => context.new_rustfunc(PyFrozenSetRef::union), "__or__" => context.new_rustfunc(PyFrozenSetRef::or), "intersection" => context.new_rustfunc(PyFrozenSetRef::intersection), - "__and__" => context.new_rustfunc(PyFrozenSetRef::intersection), + "__and__" => context.new_rustfunc(PyFrozenSetRef::and), "difference" => context.new_rustfunc(PyFrozenSetRef::difference), - "__sub__" => context.new_rustfunc(PyFrozenSetRef::difference), + "__sub__" => context.new_rustfunc(PyFrozenSetRef::sub), "symmetric_difference" => context.new_rustfunc(PyFrozenSetRef::symmetric_difference), "__xor__" => context.new_rustfunc(PyFrozenSetRef::symmetric_difference), "__contains__" => context.new_rustfunc(PyFrozenSetRef::contains), From 3ed8727ee5b8252941d19e6c7e318510826ae903 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 23:01:49 +0300 Subject: [PATCH 08/11] symmetric_difference support iterable --- tests/snippets/set.py | 4 +++ vm/src/obj/objset.rs | 69 ++++++++++++++++++------------------------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index a7977f35de..42237dd2e1 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -73,9 +73,11 @@ def __hash__(self): assert set([1,2,3]).symmetric_difference(set([1,2])) == set([3]) assert set([1,2,3]).symmetric_difference(set([5,6])) == set([1,2,3,5,6]) +assert set([1,2,3]).symmetric_difference([1,2]) == set([3]) assert set([1,2,3]) ^ set([4,5]) == set([1,2,3,4,5]) assert set([1,2,3]) ^ set([1,2,3,4,5]) == set([4,5]) +assert_raises(TypeError, lambda: set([1,2,3]) ^ [1,2,3,4,5]) assert_raises(TypeError, lambda: set([[]])) assert_raises(TypeError, lambda: set().add([])) @@ -211,9 +213,11 @@ def __hash__(self): assert frozenset([1,2,3]).symmetric_difference(frozenset([1,2])) == frozenset([3]) assert frozenset([1,2,3]).symmetric_difference(frozenset([5,6])) == frozenset([1,2,3,5,6]) +assert frozenset([1,2,3]).symmetric_difference([1,2]) == frozenset([3]) assert frozenset([1,2,3]) ^ frozenset([4,5]) == frozenset([1,2,3,4,5]) assert frozenset([1,2,3]) ^ frozenset([1,2,3,4,5]) == frozenset([4,5]) +assert_raises(TypeError, lambda: frozenset([1,2,3]) ^ [1,2,3,4,5]) assert_raises(TypeError, lambda: frozenset([[]])) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 508400694f..f8c37fdd65 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -190,10 +190,9 @@ impl PySetInner { fn intersection(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult { let mut elements = HashMap::new(); for item in other.iter(vm)? { - if let Ok(obj) = item { - if objbool::get_value(&self.contains(obj.clone(), vm)?) { - insert_into_set(vm, &mut elements, &obj)?; - } + let obj = item?; + if objbool::get_value(&self.contains(obj.clone(), vm)?) { + insert_into_set(vm, &mut elements, &obj)?; } } Ok(PySetInner { elements }) @@ -202,37 +201,27 @@ impl PySetInner { fn difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult { let mut elements = self.elements.clone(); for item in other.iter(vm)? { - if let Ok(obj) = item { - if objbool::get_value(&self.contains(obj.clone(), vm)?) { - remove_from_set(vm, &mut elements, &obj)?; - } + let obj = item?; + if objbool::get_value(&self.contains(obj.clone(), vm)?) { + remove_from_set(vm, &mut elements, &obj)?; } } Ok(PySetInner { elements }) } - fn symmetric_difference( - &self, - other: &PySetInner, - vm: &VirtualMachine, - ) -> PyResult { - let mut elements = HashMap::new(); + fn symmetric_difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult { + let mut new_inner = self.clone(); - for element in self.elements.iter() { - let value = other.contains(element.1.clone(), vm)?; - if !objbool::get_value(&value) { - elements.insert(element.0.clone(), element.1.clone()); - } - } - - for element in other.elements.iter() { - let value = self.contains(element.1.clone(), vm)?; - if !objbool::get_value(&value) { - elements.insert(element.0.clone(), element.1.clone()); + for item in other.iter(vm)? { + let obj = item?; + if !objbool::get_value(&self.contains(obj.clone(), vm)?) { + new_inner.add(&obj, vm)?; + } else { + new_inner.remove(&obj, vm)?; } } - Ok(PySetInner { elements }) + Ok(new_inner) } fn iter(&self, vm: &VirtualMachine) -> PyListIterator { @@ -427,14 +416,10 @@ impl PySetRef { )) } - fn symmetric_difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn symmetric_difference(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PySet { - inner: RefCell::new(match_class!(other, - set @ PySet => self.inner.borrow().symmetric_difference(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.borrow().symmetric_difference(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - )), + inner: RefCell::new(self.inner.borrow().symmetric_difference(other, vm)?), }, PySet::class(vm), None, @@ -453,6 +438,10 @@ impl PySetRef { self.difference(other.iterable, vm) } + fn xor(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.symmetric_difference(other.iterable, vm) + } + fn iter(self, vm: &VirtualMachine) -> PyListIterator { self.inner.borrow().iter(vm) } @@ -642,14 +631,10 @@ impl PyFrozenSetRef { )) } - fn symmetric_difference(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn symmetric_difference(self, other: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(PyObject::new( PyFrozenSet { - inner: match_class!(other, - set @ PySet => self.inner.symmetric_difference(&set.inner.borrow(), vm)?, - frozen @ PyFrozenSet => self.inner.symmetric_difference(&frozen.inner, vm)?, - other => {return Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class())));}, - ), + inner: self.inner.symmetric_difference(other, vm)?, }, PyFrozenSet::class(vm), None, @@ -668,6 +653,10 @@ impl PyFrozenSetRef { self.difference(other.iterable, vm) } + fn xor(self, other: SetIterable, vm: &VirtualMachine) -> PyResult { + self.symmetric_difference(other.iterable, vm) + } + fn iter(self, vm: &VirtualMachine) -> PyListIterator { self.inner.iter(vm) } @@ -796,7 +785,7 @@ pub fn init(context: &PyContext) { "difference" => context.new_rustfunc(PySetRef::difference), "__sub__" => context.new_rustfunc(PySetRef::sub), "symmetric_difference" => context.new_rustfunc(PySetRef::symmetric_difference), - "__xor__" => context.new_rustfunc(PySetRef::symmetric_difference), + "__xor__" => context.new_rustfunc(PySetRef::xor), "__doc__" => context.new_str(set_doc.to_string()), "add" => context.new_rustfunc(PySetRef::add), "remove" => context.new_rustfunc(PySetRef::remove), @@ -837,7 +826,7 @@ pub fn init(context: &PyContext) { "difference" => context.new_rustfunc(PyFrozenSetRef::difference), "__sub__" => context.new_rustfunc(PyFrozenSetRef::sub), "symmetric_difference" => context.new_rustfunc(PyFrozenSetRef::symmetric_difference), - "__xor__" => context.new_rustfunc(PyFrozenSetRef::symmetric_difference), + "__xor__" => context.new_rustfunc(PyFrozenSetRef::xor), "__contains__" => context.new_rustfunc(PyFrozenSetRef::contains), "__len__" => context.new_rustfunc(PyFrozenSetRef::len), "__doc__" => context.new_str(frozenset_doc.to_string()), From 848350d3347bac1bc404816a1dfefe0c908b7b3c Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 23:06:42 +0300 Subject: [PATCH 09/11] ior support only set and frozenset --- tests/snippets/set.py | 2 ++ vm/src/obj/objset.rs | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index 42237dd2e1..a44847f9b8 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -119,6 +119,8 @@ def __hash__(self): assert a == set([1,2,3,4,5]) with assertRaises(TypeError): a |= 1 +with assertRaises(TypeError): + a |= [1,2,3] a = set([1,2,3]) a.intersection_update([2,3,4,5]) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index f8c37fdd65..888ffae530 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -478,8 +478,8 @@ impl PySetRef { self.inner.borrow_mut().pop(vm) } - fn ior(self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().update(iterable, vm)?; + fn ior(self, iterable: SetIterable, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().update(iterable.iterable, vm)?; Ok(self.as_object().clone()) } From 5ae921dc571273343d981a20f469ce2b64bc4303 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 23:26:59 +0300 Subject: [PATCH 10/11] update function support iterable --- tests/snippets/set.py | 6 +++ vm/src/obj/objset.rs | 106 +++++++++++++++++++----------------------- 2 files changed, 53 insertions(+), 59 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index a44847f9b8..393db07549 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -132,6 +132,8 @@ def __hash__(self): assert a == set([2,3]) with assertRaises(TypeError): a &= 1 +with assertRaises(TypeError): + a &= [1,2,3] a = set([1,2,3]) a.difference_update([3,4,5]) @@ -143,6 +145,8 @@ def __hash__(self): assert a == set([1,2]) with assertRaises(TypeError): a -= 1 +with assertRaises(TypeError): + a -= [1,2,3] a = set([1,2,3]) a.symmetric_difference_update([3,4,5]) @@ -154,6 +158,8 @@ def __hash__(self): assert a == set([1,2,4,5]) with assertRaises(TypeError): a ^= 1 +with assertRaises(TypeError): + a ^= [1,2,3] # frozen set diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 888ffae530..4b60b5b5cf 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -16,7 +16,6 @@ use crate::vm::{ReprGuard, VirtualMachine}; use super::objbool; use super::objint; -use super::objiter; use super::objlist::PyListIterator; use super::objtype; use super::objtype::PyClassRef; @@ -283,39 +282,41 @@ impl PySetInner { Ok(vm.get_none()) } - fn combine_update_inner( - &mut self, - iterable: &PyObjectRef, - vm: &VirtualMachine, - op: SetCombineOperation, - ) -> PyResult { - let elements = &mut self.elements; - for element in elements.clone().iter() { - let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?; - let should_remove = match op { - SetCombineOperation::Intersection => !objbool::get_value(&value), - SetCombineOperation::Difference => objbool::get_value(&value), - }; - if should_remove { - elements.remove(&element.0.clone()); + fn intersection_update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + let temp_inner = self.copy(); + self.clear(); + for item in iterable.iter(vm)? { + let obj = item?; + if objbool::get_value(&temp_inner.contains(obj.clone(), vm)?) { + self.add(&obj, vm)?; } } Ok(vm.get_none()) } - fn ixor(&mut self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let elements_original = self.elements.clone(); - let iterator = objiter::get_iter(vm, &iterable)?; - while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { - insert_into_set(vm, &mut self.elements, &v)?; - } - for element in elements_original.iter() { - let value = vm.call_method(&iterable, "__contains__", vec![element.1.clone()])?; - if objbool::get_value(&value) { - self.elements.remove(&element.0.clone()); + fn difference_update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + for item in iterable.iter(vm)? { + let obj = item?; + if objbool::get_value(&self.contains(obj.clone(), vm)?) { + self.remove(&obj, vm)?; } } + Ok(vm.get_none()) + } + fn symmetric_difference_update( + &mut self, + iterable: PyIterable, + vm: &VirtualMachine, + ) -> PyResult { + for item in iterable.iter(vm)? { + let obj = item?; + if !objbool::get_value(&self.contains(obj.clone(), vm)?) { + self.add(&obj, vm)?; + } else { + self.remove(&obj, vm)?; + } + } Ok(vm.get_none()) } } @@ -488,49 +489,41 @@ impl PySetRef { Ok(vm.get_none()) } - fn intersection_update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().combine_update_inner( - &iterable, - vm, - SetCombineOperation::Intersection, - )?; + fn intersection_update(self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().intersection_update(iterable, vm)?; Ok(vm.get_none()) } - fn iand(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().combine_update_inner( - &iterable, - vm, - SetCombineOperation::Intersection, - )?; + fn iand(self, iterable: SetIterable, vm: &VirtualMachine) -> PyResult { + self.inner + .borrow_mut() + .intersection_update(iterable.iterable, vm)?; Ok(self.as_object().clone()) } - fn difference_update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().combine_update_inner( - &iterable, - vm, - SetCombineOperation::Difference, - )?; + fn difference_update(self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + self.inner.borrow_mut().difference_update(iterable, vm)?; Ok(vm.get_none()) } - fn isub(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().combine_update_inner( - &iterable, - vm, - SetCombineOperation::Difference, - )?; + fn isub(self, iterable: SetIterable, vm: &VirtualMachine) -> PyResult { + self.inner + .borrow_mut() + .difference_update(iterable.iterable, vm)?; Ok(self.as_object().clone()) } - fn symmetric_difference_update(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().ixor(iterable, vm)?; + fn symmetric_difference_update(self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + self.inner + .borrow_mut() + .symmetric_difference_update(iterable, vm)?; Ok(vm.get_none()) } - fn ixor(self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.borrow_mut().ixor(iterable, vm)?; + fn ixor(self, iterable: SetIterable, vm: &VirtualMachine) -> PyResult { + self.inner + .borrow_mut() + .symmetric_difference_update(iterable.iterable, vm)?; Ok(self.as_object().clone()) } } @@ -728,11 +721,6 @@ fn remove_from_set( perform_action_with_hash(vm, elements, item, &remove) } -enum SetCombineOperation { - Intersection, - Difference, -} - struct SetIterable { iterable: PyIterable, } From abc72e999217e0de9318a04904562182a79ff62b Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sun, 7 Apr 2019 23:54:42 +0300 Subject: [PATCH 11/11] contains return bool --- vm/src/obj/objset.rs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 4b60b5b5cf..5ec4fca46e 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -88,14 +88,14 @@ impl PySetInner { } } - fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { for element in self.elements.iter() { let value = vm._eq(needle.clone(), element.1.clone())?; if objbool::get_value(&value) { - return Ok(vm.new_bool(true)); + return Ok(true); } } - Ok(vm.new_bool(false)) + Ok(false) } fn _compare_inner( @@ -124,8 +124,7 @@ impl PySetInner { return Ok(vm.new_bool(false)); } for element in get_other(swap).elements.iter() { - let value = get_zelf(swap).contains(element.1.clone(), vm)?; - if !objbool::get_value(&value) { + if !get_zelf(swap).contains(element.1.clone(), vm)? { return Ok(vm.new_bool(false)); } } @@ -190,7 +189,7 @@ impl PySetInner { let mut elements = HashMap::new(); for item in other.iter(vm)? { let obj = item?; - if objbool::get_value(&self.contains(obj.clone(), vm)?) { + if self.contains(obj.clone(), vm)? { insert_into_set(vm, &mut elements, &obj)?; } } @@ -201,7 +200,7 @@ impl PySetInner { let mut elements = self.elements.clone(); for item in other.iter(vm)? { let obj = item?; - if objbool::get_value(&self.contains(obj.clone(), vm)?) { + if self.contains(obj.clone(), vm)? { remove_from_set(vm, &mut elements, &obj)?; } } @@ -213,7 +212,7 @@ impl PySetInner { for item in other.iter(vm)? { let obj = item?; - if !objbool::get_value(&self.contains(obj.clone(), vm)?) { + if !self.contains(obj.clone(), vm)? { new_inner.add(&obj, vm)?; } else { new_inner.remove(&obj, vm)?; @@ -287,7 +286,7 @@ impl PySetInner { self.clear(); for item in iterable.iter(vm)? { let obj = item?; - if objbool::get_value(&temp_inner.contains(obj.clone(), vm)?) { + if temp_inner.contains(obj.clone(), vm)? { self.add(&obj, vm)?; } } @@ -297,7 +296,7 @@ impl PySetInner { fn difference_update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult { for item in iterable.iter(vm)? { let obj = item?; - if objbool::get_value(&self.contains(obj.clone(), vm)?) { + if self.contains(obj.clone(), vm)? { self.remove(&obj, vm)?; } } @@ -311,7 +310,7 @@ impl PySetInner { ) -> PyResult { for item in iterable.iter(vm)? { let obj = item?; - if !objbool::get_value(&self.contains(obj.clone(), vm)?) { + if !self.contains(obj.clone(), vm)? { self.add(&obj, vm)?; } else { self.remove(&obj, vm)?; @@ -343,7 +342,7 @@ impl PySetRef { } } - fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { self.inner.borrow().contains(needle, vm) } @@ -550,7 +549,7 @@ impl PyFrozenSetRef { } } - fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { self.inner.contains(needle, vm) }