Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 89 additions & 51 deletions crates/vm/src/stdlib/ctypes/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ impl PyCStructType {

// Store StgInfo with aligned size and total alignment
let mut stg_info = StgInfo::new(aligned_size, total_align);
stg_info.length = fields.len();
stg_info.format = Some(format);
stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL; // Mark as finalized
if has_pointer {
Expand Down Expand Up @@ -511,7 +512,7 @@ impl Debug for PyCStructure {
impl Constructor for PyCStructure {
type Args = FuncArgs;

fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
// Check for abstract class and extract values in a block to drop the borrow
let (total_size, total_align, length) = {
let stg_info = cls.stg_info(vm)?;
Expand All @@ -523,79 +524,116 @@ impl Constructor for PyCStructure {
stg_info_mut.flags |= StgInfoFlags::DICTFLAG_FINAL;
}

// Get _fields_ from the class using get_attr to properly search MRO
let fields_attr = cls.as_object().get_attr("_fields_", vm).ok();
// Initialize buffer with zeros using computed size
let mut new_stg_info = StgInfo::new(total_size, total_align);
new_stg_info.length = length;
PyCStructure(PyCData::from_stg_info(&new_stg_info))
.into_ref_with_type(vm, cls)
.map(Into::into)
}

fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
unimplemented!("use slot_new")
}
}

impl PyCStructure {
/// Recursively initialize positional arguments through inheritance chain
/// Returns the number of arguments consumed
fn init_pos_args(
self_obj: &Py<Self>,
type_obj: &Py<PyType>,
args: &[PyObjectRef],
kwargs: &indexmap::IndexMap<String, PyObjectRef>,
index: usize,
vm: &VirtualMachine,
) -> PyResult<usize> {
let mut current_index = index;

// Collect field names for initialization
let mut field_names: Vec<String> = Vec::new();
if let Some(fields_attr) = fields_attr {
let fields: Vec<PyObjectRef> = if let Some(list) = fields_attr.downcast_ref::<PyList>()
// 1. First process base class fields recursively
let base_clone = {
let bases = type_obj.bases.read();
if let Some(base) = bases.first()
&& base.stg_info_opt().is_some()
{
list.borrow_vec().to_vec()
} else if let Some(tuple) = fields_attr.downcast_ref::<PyTuple>() {
tuple.to_vec()
Some(base.clone())
} else {
vec![]
};
None
}
};

if let Some(ref base) = base_clone {
current_index = Self::init_pos_args(self_obj, base, args, kwargs, current_index, vm)?;
}

// 2. Process this class's _fields_
if let Some(fields_attr) = type_obj.get_direct_attr(vm.ctx.intern_str("_fields_")) {
let fields: Vec<PyObjectRef> = fields_attr.try_to_value(vm)?;

for field in fields.iter() {
let Some(field_tuple) = field.downcast_ref::<PyTuple>() else {
continue;
};
if field_tuple.len() < 2 {
continue;
if current_index >= args.len() {
break;
}
if let Some(name) = field_tuple.first().unwrap().downcast_ref::<PyStr>() {
field_names.push(name.to_string());
if let Some(tuple) = field.downcast_ref::<PyTuple>()
&& let Some(name) = tuple.first()
&& let Some(name_str) = name.downcast_ref::<PyStr>()
{
let field_name = name_str.as_str().to_owned();
// Check for duplicate in kwargs
if kwargs.contains_key(&field_name) {
return Err(vm.new_type_error(format!(
"duplicate values for field {:?}",
field_name
)));
}
self_obj.as_object().set_attr(
vm.ctx.intern_str(field_name),
args[current_index].clone(),
vm,
)?;
current_index += 1;
}
}
}

// Initialize buffer with zeros using computed size
let mut stg_info = StgInfo::new(total_size, total_align);
stg_info.length = if length > 0 {
length
} else {
field_names.len()
};
stg_info.paramfunc = super::base::ParamFunc::Structure;
let instance = PyCStructure(PyCData::from_stg_info(&stg_info));
Ok(current_index)
}
}

// Handle keyword arguments for field initialization
let py_instance = instance.into_ref_with_type(vm, cls.clone())?;
let py_obj: PyObjectRef = py_instance.clone().into();
impl Initializer for PyCStructure {
type Args = FuncArgs;

// Set field values from kwargs using standard attribute setting
for (key, value) in args.kwargs.iter() {
if field_names.iter().any(|n| n == key.as_str()) {
py_obj.set_attr(vm.ctx.intern_str(key.as_str()), value.clone(), vm)?;
fn init(zelf: crate::PyRef<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
// Struct_init: handle positional and keyword arguments
let cls = zelf.class().to_owned();

// 1. Process positional arguments recursively through inheritance chain
if !args.args.is_empty() {
let consumed =
PyCStructure::init_pos_args(&zelf, &cls, &args.args, &args.kwargs, 0, vm)?;

if consumed < args.args.len() {
return Err(vm.new_type_error("too many initializers"));
}
}

// Set field values from positional args
if args.args.len() > field_names.len() {
return Err(vm.new_type_error("too many initializers".to_string()));
}
for (i, value) in args.args.iter().enumerate() {
py_obj.set_attr(
vm.ctx.intern_str(field_names[i].as_str()),
value.clone(),
vm,
)?;
// 2. Process keyword arguments
for (key, value) in args.kwargs.iter() {
zelf.as_object()
.set_attr(vm.ctx.intern_str(key.as_str()), value.clone(), vm)?;
}

Ok(py_instance.into())
}

fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
unimplemented!("use slot_new")
Ok(())
}
}

// Note: GetAttr and SetAttr are not implemented here.
// Field access is handled by CField descriptors registered on the class.

#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor, AsBuffer))]
#[pyclass(
flags(BASETYPE, IMMUTABLETYPE),
with(Constructor, Initializer, AsBuffer)
)]
impl PyCStructure {
#[pygetset]
fn _b0_(&self) -> Option<PyObjectRef> {
Expand Down
8 changes: 5 additions & 3 deletions crates/vm/src/stdlib/ctypes/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ impl PyCUnionType {

// Store StgInfo with aligned size
let mut stg_info = StgInfo::new(aligned_size, total_align);
stg_info.length = fields.len();
stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL | StgInfoFlags::TYPEFLAG_HASUNION;
// PEP 3118 doesn't support union. Use 'B' for bytes.
stg_info.format = Some("B".to_string());
Expand Down Expand Up @@ -431,9 +432,9 @@ impl Constructor for PyCUnion {

fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
// Check for abstract class and extract values in a block to drop the borrow
let (total_size, total_align) = {
let (total_size, total_align, length) = {
let stg_info = cls.stg_info(vm)?;
(stg_info.size, stg_info.align)
(stg_info.size, stg_info.align, stg_info.length)
};

// Mark the class as finalized (instance creation finalizes the type)
Expand All @@ -442,7 +443,8 @@ impl Constructor for PyCUnion {
}

// Initialize buffer with zeros using computed size
let new_stg_info = StgInfo::new(total_size, total_align);
let mut new_stg_info = StgInfo::new(total_size, total_align);
new_stg_info.length = length;
PyCUnion(PyCData::from_stg_info(&new_stg_info))
.into_ref_with_type(vm, cls)
.map(Into::into)
Expand Down