From 0609a975b80bc204de4173145e36e2e26bd2accb Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 23 Oct 2025 14:50:05 +0900 Subject: [PATCH 1/3] SSLSession --- Lib/ssl.py | 4 +- stdlib/src/ssl.rs | 218 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 212 insertions(+), 10 deletions(-) diff --git a/Lib/ssl.py b/Lib/ssl.py index 1200d7d993..6bd23a362f 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -98,7 +98,7 @@ import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext#, MemoryBIO, SSLSession +from _ssl import _SSLContext, SSLSession #, MemoryBIO from _ssl import ( SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, SSLSyscallError, SSLEOFError, SSLCertVerificationError @@ -114,7 +114,7 @@ from _ssl import ( HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1, - HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3 + HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3, HAS_PSK ) from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index 0e9de9c0dc..73b59fe082 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -38,15 +38,16 @@ mod _ssl { }, socket::{self, PySocket}, vm::{ - PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak}, + class_or_notimplemented, convert::{ToPyException, ToPyObject}, exceptions, function::{ ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, - OptionalArg, + OptionalArg, PyComparisonValue, }, - types::Constructor, + types::{Comparable, Constructor, PyComparisonOp}, utils::ToCString, }, }; @@ -162,6 +163,8 @@ mod _ssl { const HAS_TLSv1_2: bool = true; #[pyattr] const HAS_TLSv1_3: bool = cfg!(ossl111); + #[pyattr] + const HAS_PSK: bool = true; // the openssl version from the API headers @@ -816,16 +819,22 @@ mod _ssl { let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone())) .map_err(|e| convert_openssl_error(vm, e))?; - // TODO: use this - let _ = args.session; - - Ok(PySslSocket { + let py_ssl_socket = PySslSocket { ctx: zelf, stream: PyRwLock::new(stream), socket_type, server_hostname: args.server_hostname, owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?), - }) + }; + + // Set session if provided + if let Some(session) = args.session + && !vm.is_none(&session) + { + py_ssl_socket.set_session(session, vm)?; + } + + Ok(py_ssl_socket) } } @@ -1103,6 +1112,73 @@ mod _ssl { } } + #[pygetset] + fn session(&self, _vm: &VirtualMachine) -> PyResult> { + let stream = self.stream.read(); + unsafe { + let session_ptr = sys::SSL_get_session(stream.ssl().as_ptr()); + if session_ptr.is_null() { + Ok(None) + } else { + // Increment reference count since SSL_get_session returns a borrowed reference + #[cfg(ossl110)] + let _session = sys::SSL_SESSION_up_ref(session_ptr); + + Ok(Some(PySslSession { + session: session_ptr, + ctx: self.ctx.clone(), + })) + } + } + } + + #[pygetset(setter)] + fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // Check if value is SSLSession type + let session = value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?; + + // Check if session refers to the same SSLContext + if !std::ptr::eq( + self.ctx.ctx.read().as_ptr(), + session.ctx.ctx.read().as_ptr(), + ) { + return Err( + vm.new_value_error("Session refers to a different SSLContext.".to_owned()) + ); + } + + // Check if this is a client socket + if self.socket_type != SslServerOrClient::Client { + return Err( + vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned()) + ); + } + + // Check if handshake is not finished + let stream = self.stream.read(); + unsafe { + if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 { + return Err( + vm.new_value_error("Cannot set session after handshake.".to_owned()) + ); + } + + if sys::SSL_set_session(stream.ssl().as_ptr(), session.session) == 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + } + + Ok(()) + } + + #[pygetset] + fn session_reused(&self) -> bool { + let stream = self.stream.read(); + unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 } + } + #[pymethod] fn read( &self, @@ -1164,6 +1240,132 @@ mod _ssl { } } + #[pyattr] + #[pyclass(module = "ssl", name = "SSLSession")] + #[derive(PyPayload)] + struct PySslSession { + session: *mut sys::SSL_SESSION, + ctx: PyRef, + } + + impl fmt::Debug for PySslSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("SSLSession") + } + } + + impl Drop for PySslSession { + fn drop(&mut self) { + if !self.session.is_null() { + unsafe { + sys::SSL_SESSION_free(self.session); + } + } + } + } + + unsafe impl Send for PySslSession {} + unsafe impl Sync for PySslSession {} + + impl Comparable for PySslSession { + fn cmp( + zelf: &Py, + other: &crate::vm::PyObject, + op: PyComparisonOp, + _vm: &VirtualMachine, + ) -> PyResult { + let other = class_or_notimplemented!(Self, other); + + if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) { + return Ok(PyComparisonValue::NotImplemented); + } + let mut eq = unsafe { + let mut self_len: libc::c_uint = 0; + let mut other_len: libc::c_uint = 0; + let self_id = sys::SSL_SESSION_get_id(zelf.session, &mut self_len); + let other_id = sys::SSL_SESSION_get_id(other.session, &mut other_len); + + if self_len != other_len { + false + } else { + let self_slice = std::slice::from_raw_parts(self_id, self_len as usize); + let other_slice = std::slice::from_raw_parts(other_id, other_len as usize); + self_slice == other_slice + } + }; + if matches!(op, PyComparisonOp::Ne) { + eq = !eq; + } + Ok(PyComparisonValue::Implemented(eq)) + } + } + + #[pyclass(with(Comparable))] + impl PySslSession { + #[pygetset] + fn time(&self) -> i64 { + unsafe { + #[cfg(ossl330)] + { + sys::SSL_SESSION_get_time(self.session) as i64 + } + #[cfg(not(ossl330))] + { + sys::SSL_SESSION_get_time(self.session) as i64 + } + } + } + + #[pygetset] + fn timeout(&self) -> i64 { + unsafe { sys::SSL_SESSION_get_timeout(self.session) as i64 } + } + + #[pygetset] + fn ticket_lifetime_hint(&self) -> u64 { + // SSL_SESSION_get_ticket_lifetime_hint may not be available in older OpenSSL + // Return 0 as default if not available + #[cfg(ossl110)] + { + // For now, return 0 as this function may not be in openssl-sys + let _ = self.session; + 0 + } + #[cfg(not(ossl110))] + { + let _ = self.session; + 0 + } + } + + #[pygetset] + fn id(&self, vm: &VirtualMachine) -> PyObjectRef { + unsafe { + let mut len: libc::c_uint = 0; + let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len); + let id_slice = std::slice::from_raw_parts(id_ptr, len as usize); + vm.ctx.new_bytes(id_slice.to_vec()).into() + } + } + + #[pygetset] + fn has_ticket(&self) -> bool { + // SSL_SESSION_has_ticket may not be available in older OpenSSL + // Return false as default + #[cfg(ossl110)] + { + // For now, return false as this function may not be in openssl-sys + let _ = self.session; + false + } + #[cfg(not(ossl110))] + { + let _ = self.session; + false + } + } + } + #[track_caller] fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef { let cls = ssl_error(vm); From 58a8219d841c4540537e9251003847e517df903e Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 23 Oct 2025 17:05:53 +0900 Subject: [PATCH 2/3] get_unverified_chain --- stdlib/src/ssl.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index 73b59fe082..bfb49c7edf 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -1005,6 +1005,19 @@ mod _ssl { .transpose() } + #[pymethod] + fn get_unverified_chain(&self, vm: &VirtualMachine) -> Option { + let stream = self.stream.read(); + let chain = stream.ssl().peer_cert_chain()?; + + let certs: Vec = chain + .iter() + .filter_map(|cert| cert.to_der().ok().map(|der| vm.ctx.new_bytes(der).into())) + .collect(); + + Some(vm.ctx.new_list(certs).into()) + } + #[pymethod] fn version(&self) -> Option<&'static str> { let v = self.stream.read().ssl().version_str(); From 75c18e1d222cba8778a9f83bdd60d8fade7dde54 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 23 Oct 2025 15:34:56 +0900 Subject: [PATCH 3/3] SSL MemoryBIO --- Lib/ssl.py | 2 +- stdlib/src/ssl.rs | 173 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) diff --git a/Lib/ssl.py b/Lib/ssl.py index 6bd23a362f..751d79fb5e 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -98,7 +98,7 @@ import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext, SSLSession #, MemoryBIO +from _ssl import _SSLContext, SSLSession, MemoryBIO from _ssl import ( SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, SSLSyscallError, SSLEOFError, SSLCertVerificationError diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index bfb49c7edf..c9a9e15f8e 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -836,6 +836,29 @@ mod _ssl { Ok(py_ssl_socket) } + + #[pymethod] + fn _wrap_bio(_zelf: PyRef, _args: WrapBioArgs, vm: &VirtualMachine) -> PyResult { + // TODO: Implement BIO-based SSL wrapping + // This requires refactoring PySslSocket to support both socket and BIO modes + Err(vm.new_not_implemented_error( + "_wrap_bio is not yet implemented in RustPython".to_owned(), + )) + } + } + + #[derive(FromArgs)] + #[allow(dead_code)] // Fields will be used when _wrap_bio is fully implemented + struct WrapBioArgs { + incoming: PyRef, + outgoing: PyRef, + server_side: bool, + #[pyarg(any, default)] + server_hostname: Option, + #[pyarg(named, default)] + owner: Option, + #[pyarg(named, default)] + session: Option, } #[derive(FromArgs)] @@ -1313,6 +1336,156 @@ mod _ssl { } } + #[pyattr] + #[pyclass(module = "ssl", name = "MemoryBIO")] + #[derive(PyPayload)] + struct PySslMemoryBio { + bio: *mut sys::BIO, + eof_written: AtomicCell, + } + + impl fmt::Debug for PySslMemoryBio { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("MemoryBIO") + } + } + + impl Drop for PySslMemoryBio { + fn drop(&mut self) { + if !self.bio.is_null() { + unsafe { + sys::BIO_free_all(self.bio); + } + } + } + } + + unsafe impl Send for PySslMemoryBio {} + unsafe impl Sync for PySslMemoryBio {} + + // OpenSSL BIO helper functions + // These are typically macros in OpenSSL, implemented via BIO_ctrl + const BIO_CTRL_PENDING: libc::c_int = 10; + const BIO_CTRL_SET_EOF: libc::c_int = 2; + + #[allow(non_snake_case)] + unsafe fn BIO_ctrl_pending(bio: *mut sys::BIO) -> usize { + unsafe { sys::BIO_ctrl(bio, BIO_CTRL_PENDING, 0, std::ptr::null_mut()) as usize } + } + + #[allow(non_snake_case)] + unsafe fn BIO_set_mem_eof_return(bio: *mut sys::BIO, eof: libc::c_int) -> libc::c_int { + unsafe { + sys::BIO_ctrl( + bio, + BIO_CTRL_SET_EOF, + eof as libc::c_long, + std::ptr::null_mut(), + ) as libc::c_int + } + } + + #[allow(non_snake_case)] + unsafe fn BIO_clear_retry_flags(bio: *mut sys::BIO) { + unsafe { + sys::BIO_clear_flags(bio, sys::BIO_FLAGS_RWS | sys::BIO_FLAGS_SHOULD_RETRY); + } + } + + impl Constructor for PySslMemoryBio { + type Args = (); + + fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { + unsafe { + let bio = sys::BIO_new(sys::BIO_s_mem()); + if bio.is_null() { + return Err(vm.new_memory_error("failed to allocate BIO".to_owned())); + } + + sys::BIO_set_retry_read(bio); + BIO_set_mem_eof_return(bio, -1); + + PySslMemoryBio { + bio, + eof_written: AtomicCell::new(false), + } + .into_ref_with_type(vm, cls) + .map(Into::into) + } + } + } + + #[pyclass(with(Constructor))] + impl PySslMemoryBio { + #[pygetset] + fn pending(&self) -> usize { + unsafe { BIO_ctrl_pending(self.bio) } + } + + #[pygetset] + fn eof(&self) -> bool { + let pending = unsafe { BIO_ctrl_pending(self.bio) }; + pending == 0 && self.eof_written.load() + } + + #[pymethod] + fn read(&self, size: OptionalArg, vm: &VirtualMachine) -> PyResult> { + unsafe { + let avail = BIO_ctrl_pending(self.bio).min(i32::MAX as usize) as i32; + let len = size.unwrap_or(-1); + let len = if len < 0 || len > avail { avail } else { len }; + + if len == 0 { + return Ok(Vec::new()); + } + + let mut buf = vec![0u8; len as usize]; + let nbytes = sys::BIO_read(self.bio, buf.as_mut_ptr() as *mut _, len); + + if nbytes < 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + + buf.truncate(nbytes as usize); + Ok(buf) + } + } + + #[pymethod] + fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { + if self.eof_written.load() { + return Err(vm.new_exception_msg( + ssl_error(vm), + "cannot write() after write_eof()".to_owned(), + )); + } + + data.with_ref(|buf| unsafe { + if buf.len() > i32::MAX as usize { + return Err( + vm.new_overflow_error(format!("string longer than {} bytes", i32::MAX)) + ); + } + + let nbytes = sys::BIO_write(self.bio, buf.as_ptr() as *const _, buf.len() as i32); + if nbytes < 0 { + return Err(convert_openssl_error(vm, ErrorStack::get())); + } + + Ok(nbytes) + }) + } + + #[pymethod] + fn write_eof(&self) { + self.eof_written.store(true); + unsafe { + BIO_clear_retry_flags(self.bio); + BIO_set_mem_eof_return(self.bio, 0); + } + } + } + #[pyclass(with(Comparable))] impl PySslSession { #[pygetset]