diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 4ad592a0d2..9f4f3df82f 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -251,6 +251,26 @@ impl Compiler { } } + /// Push the next symbol table on to the stack + fn push_symbol_table(&mut self) -> &SymbolTable { + // Look up the next table contained in the scope of the current table + let table = self + .symbol_table_stack + .last_mut() + .expect("no next symbol table") + .sub_tables + .remove(0); + // Push the next table onto the stack + let last_idx = self.symbol_table_stack.len(); + self.symbol_table_stack.push(table); + &self.symbol_table_stack[last_idx] + } + + /// Pop the current symbol table off the stack + fn pop_symbol_table(&mut self) -> SymbolTable { + self.symbol_table_stack.pop().expect("compiler bug") + } + fn push_output( &mut self, flags: bytecode::CodeFlags, @@ -262,12 +282,7 @@ impl Compiler { let source_path = self.source_path.clone(); let first_line_number = self.get_source_line_number(); - let table = self - .symbol_table_stack - .last_mut() - .unwrap() - .sub_tables - .remove(0); + let table = self.push_symbol_table(); let cellvar_cache = table .symbols @@ -284,8 +299,6 @@ impl Compiler { .map(|(var, _)| var.clone()) .collect(); - self.symbol_table_stack.push(table); - let info = ir::CodeInfo { flags, posonlyarg_count, @@ -307,7 +320,7 @@ impl Compiler { } fn pop_code_object(&mut self) -> CodeObject { - let table = self.symbol_table_stack.pop().unwrap(); + let table = self.pop_symbol_table(); assert!(table.sub_tables.is_empty()); self.code_stack .pop() @@ -752,6 +765,7 @@ impl Compiler { body, decorator_list, returns, + type_params, .. }) => self.compile_function_def( name.as_str(), @@ -760,6 +774,7 @@ impl Compiler { decorator_list, returns.as_deref(), false, + type_params, )?, Stmt::AsyncFunctionDef(StmtAsyncFunctionDef { name, @@ -767,6 +782,7 @@ impl Compiler { body, decorator_list, returns, + type_params, .. }) => self.compile_function_def( name.as_str(), @@ -775,6 +791,7 @@ impl Compiler { decorator_list, returns.as_deref(), true, + type_params, )?, Stmt::ClassDef(StmtClassDef { name, @@ -782,8 +799,16 @@ impl Compiler { bases, keywords, decorator_list, + type_params, .. - }) => self.compile_class_def(name.as_str(), body, bases, keywords, decorator_list)?, + }) => self.compile_class_def( + name.as_str(), + body, + bases, + keywords, + decorator_list, + type_params, + )?, Stmt::Assert(StmtAssert { test, msg, .. }) => { // if some flag, ignore all assert statements! if self.opts.optimize == 0 { @@ -885,7 +910,27 @@ impl Compiler { Stmt::Pass(_) => { // No need to emit any code here :) } - Stmt::TypeAlias(_) => {} + Stmt::TypeAlias(StmtTypeAlias { + name, + type_params, + value, + .. + }) => { + let name_string = name.to_string(); + if !type_params.is_empty() { + self.push_symbol_table(); + } + self.compile_expression(value)?; + self.compile_type_params(type_params)?; + if !type_params.is_empty() { + self.pop_symbol_table(); + } + self.emit_load_const(ConstantData::Str { + value: name_string.clone(), + }); + emit!(self, Instruction::TypeAlias); + self.store_name(&name_string)?; + } } Ok(()) } @@ -1005,6 +1050,47 @@ impl Compiler { } } + /// Store each type parameter so it is accessible to the current scope, and leave a tuple of + /// all the type parameters on the stack. + fn compile_type_params(&mut self, type_params: &[located_ast::TypeParam]) -> CompileResult<()> { + for type_param in type_params { + match type_param { + located_ast::TypeParam::TypeVar(located_ast::TypeParamTypeVar { + name, + bound, + .. + }) => { + if let Some(expr) = &bound { + self.compile_expression(expr)?; + self.emit_load_const(ConstantData::Str { + value: name.to_string(), + }); + emit!(self, Instruction::TypeVarWithBound); + emit!(self, Instruction::Duplicate); + self.store_name(name.as_ref())?; + } else { + // self.store_name(type_name.as_str())?; + self.emit_load_const(ConstantData::Str { + value: name.to_string(), + }); + emit!(self, Instruction::TypeVar); + emit!(self, Instruction::Duplicate); + self.store_name(name.as_ref())?; + } + } + located_ast::TypeParam::ParamSpec(_) => todo!(), + located_ast::TypeParam::TypeVarTuple(_) => todo!(), + }; + } + emit!( + self, + Instruction::BuildTuple { + size: u32::try_from(type_params.len()).unwrap(), + } + ); + Ok(()) + } + fn compile_try_statement( &mut self, body: &[located_ast::Stmt], @@ -1151,6 +1237,7 @@ impl Compiler { is_forbidden_name(name) } + #[allow(clippy::too_many_arguments)] fn compile_function_def( &mut self, name: &str, @@ -1159,10 +1246,15 @@ impl Compiler { decorator_list: &[located_ast::Expr], returns: Option<&located_ast::Expr>, // TODO: use type hint somehow.. is_async: bool, + type_params: &[located_ast::TypeParam], ) -> CompileResult<()> { - // Create bytecode for this function: - self.prepare_decorators(decorator_list)?; + + // If there are type params, we need to push a special symbol table just for them + if !type_params.is_empty() { + self.push_symbol_table(); + } + let mut func_flags = self.enter_function(name, args)?; self.current_code_info() .flags @@ -1208,6 +1300,12 @@ impl Compiler { self.qualified_path.pop(); self.ctx = prev_ctx; + // Prepare generic type parameters: + if !type_params.is_empty() { + self.compile_type_params(type_params)?; + func_flags |= bytecode::MakeFunctionFlags::TYPE_PARAMS; + } + // Prepare type annotations: let mut num_annotations = 0; @@ -1253,6 +1351,11 @@ impl Compiler { func_flags |= bytecode::MakeFunctionFlags::CLOSURE; } + // Pop the special type params symbol table + if !type_params.is_empty() { + self.pop_symbol_table(); + } + self.emit_load_const(ConstantData::Code { code: Box::new(code), }); @@ -1352,6 +1455,7 @@ impl Compiler { bases: &[located_ast::Expr], keywords: &[located_ast::Keyword], decorator_list: &[located_ast::Expr], + type_params: &[located_ast::TypeParam], ) -> CompileResult<()> { self.prepare_decorators(decorator_list)?; @@ -1378,6 +1482,11 @@ impl Compiler { self.push_qualified_path(name); let qualified_name = self.qualified_path.join("."); + // If there are type params, we need to push a special symbol table just for them + if !type_params.is_empty() { + self.push_symbol_table(); + } + self.push_output(bytecode::CodeFlags::empty(), 0, 0, 0, name.to_owned()); let (doc_str, body) = split_doc(body, &self.opts); @@ -1428,10 +1537,21 @@ impl Compiler { let mut func_flags = bytecode::MakeFunctionFlags::empty(); + // Prepare generic type parameters: + if !type_params.is_empty() { + self.compile_type_params(type_params)?; + func_flags |= bytecode::MakeFunctionFlags::TYPE_PARAMS; + } + if self.build_closure(&code) { func_flags |= bytecode::MakeFunctionFlags::CLOSURE; } + // Pop the special type params symbol table + if !type_params.is_empty() { + self.pop_symbol_table(); + } + self.emit_load_const(ConstantData::Code { code: Box::new(code), }); diff --git a/compiler/codegen/src/symboltable.rs b/compiler/codegen/src/symboltable.rs index 75e327fe34..8522c82037 100644 --- a/compiler/codegen/src/symboltable.rs +++ b/compiler/codegen/src/symboltable.rs @@ -70,6 +70,7 @@ pub enum SymbolTableType { Class, Function, Comprehension, + TypeParams, } impl fmt::Display for SymbolTableType { @@ -79,6 +80,14 @@ impl fmt::Display for SymbolTableType { SymbolTableType::Class => write!(f, "class"), SymbolTableType::Function => write!(f, "function"), SymbolTableType::Comprehension => write!(f, "comprehension"), + SymbolTableType::TypeParams => write!(f, "type parameter"), + // TODO missing types from the C implementation + // if self._table.type == _symtable.TYPE_ANNOTATION: + // return "annotation" + // if self._table.type == _symtable.TYPE_TYPE_VAR_BOUND: + // return "TypeVar bound" + // if self._table.type == _symtable.TYPE_TYPE_ALIAS: + // return "type alias" } } } @@ -517,6 +526,9 @@ impl SymbolTableAnalyzer { self.analyze_symbol_comprehension(symbol, parent_offset + 1)?; } + SymbolTableType::TypeParams => { + todo!("analyze symbol comprehension for type params"); + } } Ok(()) } @@ -658,6 +670,7 @@ impl SymbolTableBuilder { body, args, decorator_list, + type_params, returns, range, .. @@ -667,6 +680,7 @@ impl SymbolTableBuilder { body, args, decorator_list, + type_params, returns, range, .. @@ -676,9 +690,20 @@ impl SymbolTableBuilder { if let Some(expression) = returns { self.scan_annotation(expression)?; } + if !type_params.is_empty() { + self.enter_scope( + &format!("", name.as_str()), + SymbolTableType::TypeParams, + range.start.row.get(), + ); + self.scan_type_params(type_params)?; + } self.enter_function(name.as_str(), args, range.start.row)?; self.scan_statements(body)?; self.leave_scope(); + if !type_params.is_empty() { + self.leave_scope(); + } } Stmt::ClassDef(StmtClassDef { name, @@ -686,9 +711,17 @@ impl SymbolTableBuilder { bases, keywords, decorator_list, - type_params: _, + type_params, range, }) => { + if !type_params.is_empty() { + self.enter_scope( + &format!("", name.as_str()), + SymbolTableType::TypeParams, + range.start.row.get(), + ); + self.scan_type_params(type_params)?; + } self.enter_scope(name.as_str(), SymbolTableType::Class, range.start.row.get()); let prev_class = std::mem::replace(&mut self.class_name, Some(name.to_string())); self.register_name("__module__", SymbolUsage::Assigned, range.start)?; @@ -702,6 +735,9 @@ impl SymbolTableBuilder { for keyword in keywords { self.scan_expression(&keyword.value, ExpressionContext::Load)?; } + if !type_params.is_empty() { + self.leave_scope(); + } self.scan_expressions(decorator_list, ExpressionContext::Load)?; self.register_name(name.as_str(), SymbolUsage::Assigned, range.start)?; } @@ -864,7 +900,26 @@ impl SymbolTableBuilder { self.scan_expression(expression, ExpressionContext::Load)?; } } - Stmt::TypeAlias(StmtTypeAlias { .. }) => {} + Stmt::TypeAlias(StmtTypeAlias { + name, + value, + type_params, + range, + }) => { + if !type_params.is_empty() { + self.enter_scope( + &name.to_string(), + SymbolTableType::TypeParams, + range.start.row.get(), + ); + self.scan_type_params(type_params)?; + self.scan_expression(value, ExpressionContext::Load)?; + self.leave_scope(); + } else { + self.scan_expression(value, ExpressionContext::Load)?; + } + self.scan_expression(name, ExpressionContext::Store)?; + } } Ok(()) } @@ -1187,6 +1242,26 @@ impl SymbolTableBuilder { Ok(()) } + fn scan_type_params(&mut self, type_params: &[ast::located::TypeParam]) -> SymbolTableResult { + for type_param in type_params { + match type_param { + ast::located::TypeParam::TypeVar(ast::TypeParamTypeVar { + name, + bound, + range: type_var_range, + }) => { + self.register_name(name.as_str(), SymbolUsage::Assigned, type_var_range.start)?; + if let Some(binding) = bound { + self.scan_expression(binding, ExpressionContext::Load)?; + } + } + ast::located::TypeParam::ParamSpec(_) => todo!(), + ast::located::TypeParam::TypeVarTuple(_) => todo!(), + } + } + Ok(()) + } + fn enter_function( &mut self, name: &str, diff --git a/compiler/core/src/bytecode.rs b/compiler/core/src/bytecode.rs index b00989ffe4..c8dbc63744 100644 --- a/compiler/core/src/bytecode.rs +++ b/compiler/core/src/bytecode.rs @@ -591,7 +591,14 @@ pub enum Instruction { GetANext, EndAsyncFor, ExtendedArg, + TypeVar, + TypeVarWithBound, + TypeVarWithConstraint, + TypeAlias, + // If you add a new instruction here, be sure to keep LAST_INSTRUCTION updated } +// This must be kept up to date to avoid marshaling errors +const LAST_INSTRUCTION: Instruction = Instruction::TypeAlias; const _: () = assert!(mem::size_of::() == 1); impl From for u8 { @@ -607,7 +614,7 @@ impl TryFrom for Instruction { #[inline] fn try_from(value: u8) -> Result { - if value <= u8::from(Instruction::ExtendedArg) { + if value <= u8::from(LAST_INSTRUCTION) { Ok(unsafe { std::mem::transmute::(value) }) } else { Err(crate::marshal::MarshalError::InvalidBytecode) @@ -639,6 +646,7 @@ bitflags! { const ANNOTATIONS = 0x02; const KW_ONLY_DEFAULTS = 0x04; const DEFAULTS = 0x08; + const TYPE_PARAMS = 0x10; } } impl OpArgType for MakeFunctionFlags { @@ -1279,6 +1287,10 @@ impl Instruction { GetANext => 1, EndAsyncFor => -2, ExtendedArg => 0, + TypeVar => 0, + TypeVarWithBound => -1, + TypeVarWithConstraint => -1, + TypeAlias => -2, } } @@ -1444,6 +1456,10 @@ impl Instruction { GetANext => w!(GetANext), EndAsyncFor => w!(EndAsyncFor), ExtendedArg => w!(ExtendedArg, Arg::::marker()), + TypeVar => w!(TypeVar), + TypeVarWithBound => w!(TypeVarWithBound), + TypeVarWithConstraint => w!(TypeVarWithConstraint), + TypeAlias => w!(TypeAlias), } } } diff --git a/vm/src/frame.rs b/vm/src/frame.rs index cc3e26585f..606ab83001 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -3,7 +3,7 @@ use crate::{ builtins::{ asyncgenerator::PyAsyncGenWrappedValue, function::{PyCell, PyCellRef, PyFunction}, - tuple::{PyTuple, PyTupleTyped}, + tuple::{PyTuple, PyTupleRef, PyTupleTyped}, PyBaseExceptionRef, PyCode, PyCoroutine, PyDict, PyDictRef, PyGenerator, PyList, PySet, PySlice, PyStr, PyStrInterned, PyStrRef, PyTraceback, PyType, }, @@ -15,7 +15,7 @@ use crate::{ protocol::{PyIter, PyIterReturn}, scope::Scope, source_code::SourceLocation, - stdlib::builtins, + stdlib::{builtins, typing::_typing}, vm::{Context, PyMethod}, AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; @@ -1162,6 +1162,46 @@ impl ExecutingFrame<'_> { *extend_arg = true; Ok(None) } + bytecode::Instruction::TypeVar => { + let type_name = self.pop_value(); + let type_var: PyObjectRef = + _typing::make_typevar(vm, type_name.clone(), vm.ctx.none(), vm.ctx.none()) + .into_ref(&vm.ctx) + .into(); + self.push_value(type_var); + Ok(None) + } + bytecode::Instruction::TypeVarWithBound => { + let type_name = self.pop_value(); + let bound = self.pop_value(); + let type_var: PyObjectRef = + _typing::make_typevar(vm, type_name.clone(), bound, vm.ctx.none()) + .into_ref(&vm.ctx) + .into(); + self.push_value(type_var); + Ok(None) + } + bytecode::Instruction::TypeVarWithConstraint => { + let type_name = self.pop_value(); + let constraint = self.pop_value(); + let type_var: PyObjectRef = + _typing::make_typevar(vm, type_name.clone(), vm.ctx.none(), constraint) + .into_ref(&vm.ctx) + .into(); + self.push_value(type_var); + Ok(None) + } + bytecode::Instruction::TypeAlias => { + let name = self.pop_value(); + let type_params: PyTupleRef = self + .pop_value() + .downcast() + .map_err(|_| vm.new_type_error("Type params must be a tuple.".to_owned()))?; + let value = self.pop_value(); + let type_alias = _typing::TypeAliasType::new(name, type_params, value); + self.push_value(type_alias.into_ref(&vm.ctx).into()); + Ok(None) + } } } @@ -1663,6 +1703,14 @@ impl ExecutingFrame<'_> { vm.ctx.new_dict().into() }; + let type_params: PyTupleRef = if flags.contains(bytecode::MakeFunctionFlags::TYPE_PARAMS) { + self.pop_value() + .downcast() + .map_err(|_| vm.new_type_error("Type params must be a tuple.".to_owned()))? + } else { + vm.ctx.empty_tuple.clone() + }; + let kw_only_defaults = if flags.contains(bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS) { Some( self.pop_value() @@ -1693,7 +1741,7 @@ impl ExecutingFrame<'_> { defaults, kw_only_defaults, qualified_name.clone(), - vm.ctx.empty_tuple.clone(), // FIXME: fake implementation + type_params, ) .into_pyobject(vm); diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index a6046ac72a..12baee11f7 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -21,6 +21,7 @@ mod sysconfigdata; #[cfg(feature = "threading")] pub mod thread; pub mod time; +pub mod typing; pub mod warnings; mod weakref; @@ -85,6 +86,7 @@ pub fn get_module_inits() -> StdlibMap { "_sre" => sre::make_module, "_string" => string::make_module, "time" => time::make_module, + "_typing" => typing::make_module, "_weakref" => weakref::make_module, "_imp" => imp::make_module, "_warnings" => warnings::make_module, diff --git a/vm/src/stdlib/typing.rs b/vm/src/stdlib/typing.rs new file mode 100644 index 0000000000..daa0180325 --- /dev/null +++ b/vm/src/stdlib/typing.rs @@ -0,0 +1,159 @@ +pub(crate) use _typing::make_module; + +#[pymodule] +pub(crate) mod _typing { + use crate::{ + builtins::{pystr::AsPyStr, PyGenericAlias, PyTupleRef, PyTypeRef}, + function::IntoFuncArgs, + PyObjectRef, PyPayload, PyResult, VirtualMachine, + }; + + pub(crate) fn _call_typing_func_object<'a>( + _vm: &VirtualMachine, + _func_name: impl AsPyStr<'a>, + _args: impl IntoFuncArgs, + ) -> PyResult { + todo!("does this work????"); + // let module = vm.import("typing", 0)?; + // let module = vm.import("_pycodecs", None, 0)?; + // let func = module.get_attr(func_name, vm)?; + // func.call(args, vm) + } + + #[pyattr] + pub(crate) fn _idfunc(_vm: &VirtualMachine) {} + + #[pyattr] + #[pyclass(name = "TypeVar")] + #[derive(Debug, PyPayload)] + #[allow(dead_code)] + pub(crate) struct TypeVar { + name: PyObjectRef, // TODO PyStrRef? + bound: parking_lot::Mutex, + evaluate_bound: PyObjectRef, + constraints: parking_lot::Mutex, + evaluate_constraints: PyObjectRef, + covariant: bool, + contravariant: bool, + infer_variance: bool, + } + #[pyclass(flags(BASETYPE))] + impl TypeVar { + pub(crate) fn _bound(&self, vm: &VirtualMachine) -> PyResult { + let mut bound = self.bound.lock(); + if !vm.is_none(&bound) { + return Ok(bound.clone()); + } + if !vm.is_none(&self.evaluate_bound) { + *bound = self.evaluate_bound.call((), vm)?; + Ok(bound.clone()) + } else { + Ok(vm.ctx.none()) + } + } + } + + pub(crate) fn make_typevar( + vm: &VirtualMachine, + name: PyObjectRef, + evaluate_bound: PyObjectRef, + evaluate_constraints: PyObjectRef, + ) -> TypeVar { + TypeVar { + name, + bound: parking_lot::Mutex::new(vm.ctx.none()), + evaluate_bound, + constraints: parking_lot::Mutex::new(vm.ctx.none()), + evaluate_constraints, + covariant: false, + contravariant: false, + infer_variance: true, + } + } + + #[pyattr] + #[pyclass(name = "ParamSpec")] + #[derive(Debug, PyPayload)] + #[allow(dead_code)] + struct ParamSpec {} + #[pyclass(flags(BASETYPE))] + impl ParamSpec {} + + #[pyattr] + #[pyclass(name = "TypeVarTuple")] + #[derive(Debug, PyPayload)] + #[allow(dead_code)] + pub(crate) struct TypeVarTuple {} + #[pyclass(flags(BASETYPE))] + impl TypeVarTuple {} + + #[pyattr] + #[pyclass(name = "ParamSpecArgs")] + #[derive(Debug, PyPayload)] + #[allow(dead_code)] + pub(crate) struct ParamSpecArgs {} + #[pyclass(flags(BASETYPE))] + impl ParamSpecArgs {} + + #[pyattr] + #[pyclass(name = "ParamSpecKwargs")] + #[derive(Debug, PyPayload)] + #[allow(dead_code)] + pub(crate) struct ParamSpecKwargs {} + #[pyclass(flags(BASETYPE))] + impl ParamSpecKwargs {} + + #[pyattr] + #[pyclass(name)] + #[derive(Debug, PyPayload)] + #[allow(dead_code)] + pub(crate) struct TypeAliasType { + name: PyObjectRef, // TODO PyStrRef? + type_params: PyTupleRef, + value: PyObjectRef, + // compute_value: PyObjectRef, + // module: PyObjectRef, + } + #[pyclass(flags(BASETYPE))] + impl TypeAliasType { + pub fn new( + name: PyObjectRef, + type_params: PyTupleRef, + value: PyObjectRef, + ) -> TypeAliasType { + TypeAliasType { + name, + type_params, + value, + } + } + } + + #[pyattr] + #[pyclass(name)] + #[derive(Debug, PyPayload)] + #[allow(dead_code)] + pub(crate) struct Generic {} + + // #[pyclass(with(AsMapping), flags(BASETYPE))] + #[pyclass(flags(BASETYPE))] + impl Generic { + #[pyclassmethod(magic)] + fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias { + PyGenericAlias::new(cls, args, vm) + } + } + + // impl AsMapping for Generic { + // fn as_mapping() -> &'static PyMappingMethods { + // static AS_MAPPING: Lazy = Lazy::new(|| PyMappingMethods { + // subscript: atomic_func!(|mapping, needle, vm| { + // println!("gigity"); + // call_typing_func_object(vm, "_GenericAlias", (mapping.obj, needle)) + // }), + // ..PyMappingMethods::NOT_IMPLEMENTED + // }); + // &AS_MAPPING + // } + // } +} diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 0edfe58ce8..c6ad7aefcf 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -283,6 +283,7 @@ impl VirtualMachine { stdlib::sys::init_module(self, &self.sys_module, &self.builtins); let mut essential_init = || -> PyResult { + import::import_builtin(self, "_typing")?; #[cfg(not(target_arch = "wasm32"))] import::import_builtin(self, "_signal")?; #[cfg(any(feature = "parser", feature = "compiler"))]