diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index fbbb38821af..b06920ebe0e 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -631,6 +631,7 @@ def _add_mcp_support_if_enabled(app, store: "feast.FeatureStore"): class FeastServeApplication(gunicorn.app.base.BaseApplication): def __init__(self, store: "feast.FeatureStore", **options): + self._store = store self._app = get_app( store=store, registry_ttl_sec=options["registry_ttl_sec"], @@ -645,6 +646,24 @@ def load_config(self): self.cfg.set("worker_class", "uvicorn_worker.UvicornWorker") + # Register post_fork hook for fork-safety with SQL Registry + # This ensures each worker reinitializes database connections + # and background threads after forking + self.cfg.set("post_fork", self._post_fork_hook) + + def _post_fork_hook(self, server, worker): + """ + Gunicorn post_fork hook called in each worker after fork. + + This is critical for fork-safety when using SQL Registry backends. + SQLAlchemy connection pools and threading.Timer objects are not + fork-safe and must be reinitialized in each worker process. + """ + logger.debug(f"Worker {worker.pid} initializing after fork") + if hasattr(self._store, "registry") and self._store.registry is not None: + self._store.registry.on_worker_init() + logger.debug(f"Worker {worker.pid} registry reinitialized") + def load(self): return self._app diff --git a/sdk/python/feast/infra/offline_stores/duckdb.py b/sdk/python/feast/infra/offline_stores/duckdb.py index 7bf96129d0b..cdcd14dcc2b 100644 --- a/sdk/python/feast/infra/offline_stores/duckdb.py +++ b/sdk/python/feast/infra/offline_stores/duckdb.py @@ -1,7 +1,8 @@ +import logging import os from datetime import datetime from pathlib import Path -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union import ibis import pandas as pd @@ -26,8 +27,96 @@ from feast.infra.registry.base_registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig +logger = logging.getLogger(__name__) + +# Track whether S3 has been configured for the current DuckDB connection +_s3_configured = False + + +def _configure_duckdb_for_s3(config: "DuckDBOfflineStoreConfig") -> None: + """ + Configure the DuckDB connection for S3 access. + + This function configures DuckDB's HTTPFS extension with S3 settings. + It's designed to be called once before any S3 operations. + + Args: + config: DuckDB offline store configuration containing S3 settings. 
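+
+    Example:
+        The store methods below call this with the repo's offline store
+        config, e.g. ``_configure_duckdb_for_s3(config.offline_store)``,
+        where ``config`` is the active ``RepoConfig``.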
+ """ + global _s3_configured + + # Check if any S3 settings are configured + has_s3_settings = any( + [ + config.s3_url_style, + config.s3_endpoint, + config.s3_access_key_id, + config.s3_secret_access_key, + config.s3_region, + config.s3_use_ssl is not None, + ] + ) + + if not has_s3_settings: + return + + if _s3_configured: + return + + try: + # Get the default DuckDB connection from ibis + con = ibis.get_backend() + + # Install and load httpfs extension for S3 support + con.raw_sql("INSTALL httpfs;") + con.raw_sql("LOAD httpfs;") + + # Configure S3 settings + if config.s3_url_style: + con.raw_sql(f"SET s3_url_style='{config.s3_url_style}';") + logger.debug(f"DuckDB S3 url_style set to '{config.s3_url_style}'") + + if config.s3_endpoint: + con.raw_sql(f"SET s3_endpoint='{config.s3_endpoint}';") + logger.debug(f"DuckDB S3 endpoint set to '{config.s3_endpoint}'") + + if config.s3_access_key_id: + con.raw_sql(f"SET s3_access_key_id='{config.s3_access_key_id}';") + logger.debug("DuckDB S3 access_key_id configured") + + if config.s3_secret_access_key: + con.raw_sql(f"SET s3_secret_access_key='{config.s3_secret_access_key}';") + logger.debug("DuckDB S3 secret_access_key configured") + + if config.s3_region: + con.raw_sql(f"SET s3_region='{config.s3_region}';") + logger.debug(f"DuckDB S3 region set to '{config.s3_region}'") + + if config.s3_use_ssl is not None: + ssl_value = "true" if config.s3_use_ssl else "false" + con.raw_sql(f"SET s3_use_ssl={ssl_value};") + logger.debug(f"DuckDB S3 use_ssl set to {ssl_value}") + + _s3_configured = True + logger.info("DuckDB S3 configuration completed successfully") + + except Exception as e: + logger.warning(f"Failed to configure DuckDB for S3: {e}") + # Don't raise - let the operation continue and potentially fail with a more specific error + + +def _is_s3_path(path: str) -> bool: + """Check if the given path is an S3 path.""" + return path.startswith("s3://") or path.startswith("s3a://") + def _read_data_source(data_source: DataSource, repo_path: str) -> Table: + """ + Read data from a FileSource into an ibis Table. + + Note: S3 configuration must be set up before calling this function + by calling _configure_duckdb_for_s3() from the DuckDBOfflineStore methods. + """ assert isinstance(data_source, FileSource) if isinstance(data_source.file_format, ParquetFormat): @@ -113,12 +202,66 @@ def _write_data_source( class DuckDBOfflineStoreConfig(FeastConfigBaseModel): + """Configuration for DuckDB offline store. + + Attributes: + type: Offline store type selector. Must be "duckdb". + staging_location: Optional S3 path for staging data during remote exports. + staging_location_endpoint_override: Custom S3 endpoint for staging location. + s3_url_style: S3 URL style - "path" for path-style URLs (required for + MinIO, LocalStack, etc.) or "vhost" for virtual-hosted style. + Default is None which uses DuckDB's default (vhost). + s3_endpoint: Custom S3 endpoint URL (e.g., "localhost:9000" for MinIO). + s3_access_key_id: AWS access key ID for S3 authentication. + If not set, uses AWS credential chain. + s3_secret_access_key: AWS secret access key for S3 authentication. + If not set, uses AWS credential chain. + s3_region: AWS region for S3 access (e.g., "us-east-1"). + Required for some S3-compatible providers. + s3_use_ssl: Whether to use SSL for S3 connections. + Default is None which uses DuckDB's default (true). + + Example: + For MinIO or other S3-compatible storage that requires path-style URLs: + + .. 
code-block:: yaml + + offline_store: + type: duckdb + s3_url_style: path + s3_endpoint: localhost:9000 + s3_access_key_id: minioadmin + s3_secret_access_key: minioadmin + s3_region: us-east-1 + s3_use_ssl: false + """ + type: StrictStr = "duckdb" - # """ Offline store type selector""" staging_location: Optional[str] = None + """S3 path for staging data during remote exports.""" staging_location_endpoint_override: Optional[str] = None + """Custom S3 endpoint for staging location.""" + + # S3 configuration options for DuckDB's HTTPFS extension + s3_url_style: Optional[Literal["path", "vhost"]] = None + """S3 URL style - 'path' for path-style or 'vhost' for virtual-hosted style.""" + + s3_endpoint: Optional[str] = None + """Custom S3 endpoint URL (e.g., 'localhost:9000' for MinIO).""" + + s3_access_key_id: Optional[str] = None + """AWS access key ID. If not set, uses AWS credential chain.""" + + s3_secret_access_key: Optional[str] = None + """AWS secret access key. If not set, uses AWS credential chain.""" + + s3_region: Optional[str] = None + """AWS region (e.g., 'us-east-1'). Required for some S3-compatible providers.""" + + s3_use_ssl: Optional[bool] = None + """Whether to use SSL for S3 connections. Default uses DuckDB's default (true).""" class DuckDBOfflineStore(OfflineStore): @@ -133,6 +276,9 @@ def pull_latest_from_table_or_query( start_date: datetime, end_date: datetime, ) -> RetrievalJob: + # Configure S3 settings for DuckDB before reading data + _configure_duckdb_for_s3(config.offline_store) + return pull_latest_from_table_or_query_ibis( config=config, data_source=data_source, @@ -158,6 +304,9 @@ def get_historical_features( project: str, full_feature_names: bool = False, ) -> RetrievalJob: + # Configure S3 settings for DuckDB before reading data + _configure_duckdb_for_s3(config.offline_store) + return get_historical_features_ibis( config=config, feature_views=feature_views, @@ -183,6 +332,9 @@ def pull_all_from_table_or_query( start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, ) -> RetrievalJob: + # Configure S3 settings for DuckDB before reading data + _configure_duckdb_for_s3(config.offline_store) + return pull_all_from_table_or_query_ibis( config=config, data_source=data_source, @@ -205,6 +357,9 @@ def offline_write_batch( table: pyarrow.Table, progress: Optional[Callable[[int], Any]], ): + # Configure S3 settings for DuckDB before writing data + _configure_duckdb_for_s3(config.offline_store) + offline_write_batch_ibis( config=config, feature_view=feature_view, @@ -221,6 +376,9 @@ def write_logged_features( logging_config: LoggingConfig, registry: BaseRegistry, ): + # Configure S3 settings for DuckDB before writing data + _configure_duckdb_for_s3(config.offline_store) + write_logged_features_ibis( config=config, data=data, diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index 24e9f36fbd2..ea29a74a69e 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -805,6 +805,23 @@ def refresh(self, project: Optional[str] = None): """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" raise NotImplementedError + def on_worker_init(self): + """ + Called after a worker process has been forked to reinitialize resources. + + This method is critical for fork-safety when using multi-worker servers + like Gunicorn. 
Resources like database connection pools, threads, and + file handles are not fork-safe and must be reinitialized in child processes. + + Subclasses should override this method to: + - Dispose and recreate database connection pools + - Stop and restart background threads + - Reinitialize any other fork-unsafe resources + + This is a no-op by default for registries that don't need special handling. + """ + pass + # Lineage operations def get_registry_lineage( self, diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index ce346272af9..2ca0769012e 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -487,3 +487,29 @@ def _start_thread_async_refresh(self, cache_ttl_seconds): def _exit_handler(self): self.registry_refresh_thread.cancel() + + def on_worker_init(self): + """ + Called after a worker process has been forked to reinitialize resources. + + For CachingRegistry, this method: + 1. Cancels any inherited refresh thread from the parent process + 2. Restarts the refresh thread if cache_mode is "thread" + + This ensures each worker has its own independent refresh thread. + """ + # Cancel any inherited timer from parent process + if hasattr(self, "registry_refresh_thread") and self.registry_refresh_thread: + try: + self.registry_refresh_thread.cancel() + except Exception: + pass # Timer may already be invalid after fork + + # Restart refresh thread if using thread cache mode + if self.cache_mode == "thread": + cache_ttl_seconds = int(self.cached_registry_proto_ttl.total_seconds()) + if cache_ttl_seconds > 0: + self._start_thread_async_refresh(cache_ttl_seconds) + logger.debug( + "CachingRegistry refresh thread restarted after fork" + ) diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index 103b1f6c0a6..9ea3475d7d4 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -250,17 +250,7 @@ def __init__( ), "SqlRegistry needs a valid registry_config" self.registry_config = registry_config - - self.write_engine: Engine = create_engine( - registry_config.path, **registry_config.sqlalchemy_config_kwargs - ) - if registry_config.read_path: - self.read_engine: Engine = create_engine( - registry_config.read_path, - **registry_config.sqlalchemy_config_kwargs, - ) - else: - self.read_engine = self.write_engine + self._init_engines() metadata.create_all(self.write_engine) self.thread_pool_executor_worker_count = ( registry_config.thread_pool_executor_worker_count @@ -278,6 +268,62 @@ def __init__( if not self.purge_feast_metadata: self._maybe_init_project_metadata(project) + def _init_engines(self): + """Initialize SQLAlchemy engines. Can be called to reinitialize after fork.""" + self.write_engine: Engine = create_engine( + self.registry_config.path, **self.registry_config.sqlalchemy_config_kwargs + ) + if self.registry_config.read_path: + self.read_engine: Engine = create_engine( + self.registry_config.read_path, + **self.registry_config.sqlalchemy_config_kwargs, + ) + else: + self.read_engine = self.write_engine + + def reinitialize_engines(self): + """ + Reinitialize SQLAlchemy engines after a process fork. + + This method is critical for fork-safety when using multi-worker servers + like Gunicorn. SQLAlchemy's connection pools are not fork-safe - the + internal state (locks, conditions, connections) becomes corrupted when + a process forks. 
This method disposes the old engines and creates fresh + ones with new connection pools. + + Should be called in a post_fork hook when running with multiple workers. + """ + # Dispose existing engines to clean up connection pools + if hasattr(self, "write_engine") and self.write_engine is not None: + self.write_engine.dispose() + if ( + hasattr(self, "read_engine") + and self.read_engine is not None + and self.read_engine is not self.write_engine + ): + self.read_engine.dispose() + + # Reinitialize with fresh engines + self._init_engines() + logger.debug("SqlRegistry engines reinitialized after fork") + + def on_worker_init(self): + """ + Called after a worker process has been forked to reinitialize resources. + + For SqlRegistry, this method: + 1. Reinitializes SQLAlchemy engines (disposes old pools, creates new ones) + 2. Calls parent class to handle refresh thread restart + + This is critical for fork-safety when using multi-worker servers like + Gunicorn with SQL Registry backends. + """ + # First reinitialize the database engines + self.reinitialize_engines() + + # Then call parent to handle refresh thread + super().on_worker_init() + def _sync_feast_metadata_to_projects_table(self): feast_metadata_projects: dict = {} projects_set: set = [] diff --git a/sdk/python/tests/unit/infra/offline_stores/test_duckdb_s3_config.py b/sdk/python/tests/unit/infra/offline_stores/test_duckdb_s3_config.py new file mode 100644 index 00000000000..e037e1ca529 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/test_duckdb_s3_config.py @@ -0,0 +1,290 @@ +# Copyright 2024 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for DuckDB offline store S3 configuration. + +These tests verify the S3 configuration functionality for DuckDB, +specifically the s3_url_style and related S3 settings that enable +compatibility with S3-compatible storage providers like MinIO. 
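+
+Run with, for example:
+
+    pytest sdk/python/tests/unit/infra/offline_stores/test_duckdb_s3_config.py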
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from feast.infra.offline_stores.duckdb import ( + DuckDBOfflineStoreConfig, + _configure_duckdb_for_s3, + _is_s3_path, +) + + +class TestDuckDBOfflineStoreConfig: + """Tests for DuckDBOfflineStoreConfig S3 settings.""" + + def test_default_config_has_no_s3_settings(self): + """Test that default config has None for all S3 settings.""" + config = DuckDBOfflineStoreConfig() + + assert config.s3_url_style is None + assert config.s3_endpoint is None + assert config.s3_access_key_id is None + assert config.s3_secret_access_key is None + assert config.s3_region is None + assert config.s3_use_ssl is None + + def test_s3_url_style_path(self): + """Test that s3_url_style can be set to 'path'.""" + config = DuckDBOfflineStoreConfig(s3_url_style="path") + assert config.s3_url_style == "path" + + def test_s3_url_style_vhost(self): + """Test that s3_url_style can be set to 'vhost'.""" + config = DuckDBOfflineStoreConfig(s3_url_style="vhost") + assert config.s3_url_style == "vhost" + + def test_s3_url_style_invalid_raises_error(self): + """Test that invalid s3_url_style raises a validation error.""" + with pytest.raises(Exception): # Pydantic validation error + DuckDBOfflineStoreConfig(s3_url_style="invalid") + + def test_full_minio_config(self): + """Test a full MinIO-compatible configuration.""" + config = DuckDBOfflineStoreConfig( + s3_url_style="path", + s3_endpoint="localhost:9000", + s3_access_key_id="minioadmin", + s3_secret_access_key="minioadmin", + s3_region="us-east-1", + s3_use_ssl=False, + ) + + assert config.s3_url_style == "path" + assert config.s3_endpoint == "localhost:9000" + assert config.s3_access_key_id == "minioadmin" + assert config.s3_secret_access_key == "minioadmin" + assert config.s3_region == "us-east-1" + assert config.s3_use_ssl is False + + def test_partial_s3_config(self): + """Test that partial S3 config (only some fields) works.""" + config = DuckDBOfflineStoreConfig( + s3_url_style="path", + s3_endpoint="s3.custom-provider.com", + ) + + assert config.s3_url_style == "path" + assert config.s3_endpoint == "s3.custom-provider.com" + assert config.s3_access_key_id is None + assert config.s3_secret_access_key is None + + def test_s3_use_ssl_true(self): + """Test that s3_use_ssl can be set to True.""" + config = DuckDBOfflineStoreConfig(s3_use_ssl=True) + assert config.s3_use_ssl is True + + def test_s3_use_ssl_false(self): + """Test that s3_use_ssl can be set to False.""" + config = DuckDBOfflineStoreConfig(s3_use_ssl=False) + assert config.s3_use_ssl is False + + def test_config_with_staging_location(self): + """Test config with both S3 settings and staging location.""" + config = DuckDBOfflineStoreConfig( + staging_location="s3://my-bucket/staging", + staging_location_endpoint_override="http://localhost:9000", + s3_url_style="path", + s3_endpoint="localhost:9000", + ) + + assert config.staging_location == "s3://my-bucket/staging" + assert config.staging_location_endpoint_override == "http://localhost:9000" + assert config.s3_url_style == "path" + assert config.s3_endpoint == "localhost:9000" + + +class TestIsS3Path: + """Tests for _is_s3_path helper function.""" + + def test_s3_path(self): + """Test that s3:// paths are detected.""" + assert _is_s3_path("s3://bucket/key") is True + + def test_s3a_path(self): + """Test that s3a:// paths are detected.""" + assert _is_s3_path("s3a://bucket/key") is True + + def test_local_path(self): + """Test that local paths are not detected as S3.""" + assert 
_is_s3_path("/path/to/file.parquet") is False + + def test_http_path(self): + """Test that HTTP paths are not detected as S3.""" + assert _is_s3_path("http://example.com/file.parquet") is False + + def test_gs_path(self): + """Test that GCS paths are not detected as S3.""" + assert _is_s3_path("gs://bucket/key") is False + + +class TestConfigureDuckDBForS3: + """Tests for _configure_duckdb_for_s3 function.""" + + def setup_method(self): + """Reset the global _s3_configured flag before each test.""" + import feast.infra.offline_stores.duckdb as duckdb_module + + duckdb_module._s3_configured = False + + def test_no_config_does_nothing(self): + """Test that empty config doesn't call ibis.""" + config = DuckDBOfflineStoreConfig() + + with patch("feast.infra.offline_stores.duckdb.ibis") as mock_ibis: + _configure_duckdb_for_s3(config) + mock_ibis.get_backend.assert_not_called() + + def test_s3_url_style_configures_duckdb(self): + """Test that s3_url_style triggers DuckDB configuration.""" + config = DuckDBOfflineStoreConfig(s3_url_style="path") + + mock_con = MagicMock() + with patch( + "feast.infra.offline_stores.duckdb.ibis.get_backend", return_value=mock_con + ): + _configure_duckdb_for_s3(config) + + # Check that httpfs was installed and loaded + mock_con.raw_sql.assert_any_call("INSTALL httpfs;") + mock_con.raw_sql.assert_any_call("LOAD httpfs;") + + # Check that s3_url_style was set + mock_con.raw_sql.assert_any_call("SET s3_url_style='path';") + + def test_full_s3_config_sets_all_options(self): + """Test that all S3 options are configured correctly.""" + config = DuckDBOfflineStoreConfig( + s3_url_style="path", + s3_endpoint="localhost:9000", + s3_access_key_id="mykey", + s3_secret_access_key="mysecret", + s3_region="us-west-2", + s3_use_ssl=False, + ) + + mock_con = MagicMock() + with patch( + "feast.infra.offline_stores.duckdb.ibis.get_backend", return_value=mock_con + ): + _configure_duckdb_for_s3(config) + + # Verify all SQL commands were executed + calls = [str(call) for call in mock_con.raw_sql.call_args_list] + + assert any("s3_url_style='path'" in call for call in calls) + assert any("s3_endpoint='localhost:9000'" in call for call in calls) + assert any("s3_access_key_id='mykey'" in call for call in calls) + assert any("s3_secret_access_key='mysecret'" in call for call in calls) + assert any("s3_region='us-west-2'" in call for call in calls) + assert any("s3_use_ssl=false" in call for call in calls) + + def test_s3_use_ssl_true_sets_correctly(self): + """Test that s3_use_ssl=True is set correctly.""" + config = DuckDBOfflineStoreConfig(s3_use_ssl=True) + + mock_con = MagicMock() + with patch( + "feast.infra.offline_stores.duckdb.ibis.get_backend", return_value=mock_con + ): + _configure_duckdb_for_s3(config) + + mock_con.raw_sql.assert_any_call("SET s3_use_ssl=true;") + + def test_config_only_runs_once(self): + """Test that S3 configuration only runs once (cached).""" + config = DuckDBOfflineStoreConfig(s3_url_style="path") + + mock_con = MagicMock() + with patch( + "feast.infra.offline_stores.duckdb.ibis.get_backend", return_value=mock_con + ): + # First call should configure + _configure_duckdb_for_s3(config) + first_call_count = mock_con.raw_sql.call_count + + # Second call should not configure again + _configure_duckdb_for_s3(config) + assert mock_con.raw_sql.call_count == first_call_count + + def test_handles_ibis_error_gracefully(self): + """Test that ibis errors are handled gracefully.""" + config = DuckDBOfflineStoreConfig(s3_url_style="path") + + with patch( + 
"feast.infra.offline_stores.duckdb.ibis.get_backend", + side_effect=Exception("ibis error"), + ): + # Should not raise an exception + _configure_duckdb_for_s3(config) + + def test_only_endpoint_configured(self): + """Test configuration with only endpoint set.""" + config = DuckDBOfflineStoreConfig(s3_endpoint="minio.local:9000") + + mock_con = MagicMock() + with patch( + "feast.infra.offline_stores.duckdb.ibis.get_backend", return_value=mock_con + ): + _configure_duckdb_for_s3(config) + + # Should set endpoint but not url_style + calls = [str(call) for call in mock_con.raw_sql.call_args_list] + assert any("s3_endpoint='minio.local:9000'" in call for call in calls) + assert not any("s3_url_style" in call for call in calls) + + +class TestConfigFromYaml: + """Tests for loading DuckDB config from YAML-style dictionaries.""" + + def test_from_dict_with_s3_settings(self): + """Test creating config from dictionary (simulating YAML parsing).""" + config_dict = { + "type": "duckdb", + "s3_url_style": "path", + "s3_endpoint": "localhost:9000", + "s3_access_key_id": "minioadmin", + "s3_secret_access_key": "minioadmin", + "s3_region": "us-east-1", + "s3_use_ssl": False, + } + + config = DuckDBOfflineStoreConfig(**config_dict) + + assert config.type == "duckdb" + assert config.s3_url_style == "path" + assert config.s3_endpoint == "localhost:9000" + assert config.s3_access_key_id == "minioadmin" + assert config.s3_secret_access_key == "minioadmin" + assert config.s3_region == "us-east-1" + assert config.s3_use_ssl is False + + def test_from_dict_minimal(self): + """Test creating minimal config from dictionary.""" + config_dict = {"type": "duckdb"} + + config = DuckDBOfflineStoreConfig(**config_dict) + + assert config.type == "duckdb" + assert config.s3_url_style is None diff --git a/sdk/python/tests/unit/infra/registry/test_sql_registry.py b/sdk/python/tests/unit/infra/registry/test_sql_registry.py index 8e5154da47b..57388c4fbb0 100644 --- a/sdk/python/tests/unit/infra/registry/test_sql_registry.py +++ b/sdk/python/tests/unit/infra/registry/test_sql_registry.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing +import os +import sys import tempfile import pytest @@ -56,3 +59,201 @@ def test_sql_registry(sqlite_registry): sqlite_registry.delete_entity("test_entity", "test_project") with pytest.raises(Exception): sqlite_registry.get_entity("test_entity", "test_project") + + +def test_sql_registry_reinitialize_engines(): + """ + Test that reinitialize_engines() properly disposes and recreates engines. + + This is critical for fork-safety when using multi-worker servers. 
+ """ + fd, registry_path = tempfile.mkstemp() + registry_config = SqlRegistryConfig( + registry_type="sql", + path=f"sqlite:///{registry_path}", + purge_feast_metadata=False, + ) + + registry = SqlRegistry(registry_config, "test_project", None) + + # Store original engine references + original_write_engine = registry.write_engine + original_read_engine = registry.read_engine + + # Apply an entity before reinitializing + entity = Entity( + name="test_entity", + description="Test entity before reinitialize", + ) + registry.apply_entity(entity, "test_project") + + # Reinitialize engines + registry.reinitialize_engines() + + # Verify engines are new instances + assert registry.write_engine is not original_write_engine + assert registry.read_engine is not original_read_engine + + # Verify the registry still works after reinitialization + retrieved_entity = registry.get_entity("test_entity", "test_project") + assert retrieved_entity.name == "test_entity" + + # Apply a new entity after reinitializing + entity2 = Entity( + name="test_entity_2", + description="Test entity after reinitialize", + ) + registry.apply_entity(entity2, "test_project") + retrieved_entity2 = registry.get_entity("test_entity_2", "test_project") + assert retrieved_entity2.name == "test_entity_2" + + registry.teardown() + + +def test_sql_registry_on_worker_init(): + """ + Test that on_worker_init() properly reinitializes the registry. + + This method should be called after a process fork to ensure + the registry has fresh database connections. + """ + fd, registry_path = tempfile.mkstemp() + registry_config = SqlRegistryConfig( + registry_type="sql", + path=f"sqlite:///{registry_path}", + purge_feast_metadata=False, + ) + + registry = SqlRegistry(registry_config, "test_project", None) + + # Store original engine reference + original_write_engine = registry.write_engine + + # Apply an entity before on_worker_init + entity = Entity( + name="test_entity", + description="Test entity before worker init", + ) + registry.apply_entity(entity, "test_project") + + # Call on_worker_init (simulates what happens after fork) + registry.on_worker_init() + + # Verify engine was recreated + assert registry.write_engine is not original_write_engine + + # Verify the registry still works + retrieved_entity = registry.get_entity("test_entity", "test_project") + assert retrieved_entity.name == "test_entity" + + registry.teardown() + + +def test_sql_registry_with_separate_read_write_engines(): + """ + Test reinitialize_engines with separate read and write paths. 
+ """ + fd1, write_path = tempfile.mkstemp() + fd2, read_path = tempfile.mkstemp() + + # Use the same path for both (SQLite doesn't support true read replicas, + # but this tests the code path) + registry_config = SqlRegistryConfig( + registry_type="sql", + path=f"sqlite:///{write_path}", + read_path=f"sqlite:///{write_path}", # Same path to share data + purge_feast_metadata=False, + ) + + registry = SqlRegistry(registry_config, "test_project", None) + + # When read_path is specified, read_engine should be different from write_engine + assert registry.read_engine is not registry.write_engine + + original_write_engine = registry.write_engine + original_read_engine = registry.read_engine + + # Apply entity + entity = Entity(name="test_entity", description="Test") + registry.apply_entity(entity, "test_project") + + # Reinitialize + registry.reinitialize_engines() + + # Both engines should be new + assert registry.write_engine is not original_write_engine + assert registry.read_engine is not original_read_engine + assert registry.read_engine is not registry.write_engine + + # Verify still works + retrieved = registry.get_entity("test_entity", "test_project") + assert retrieved.name == "test_entity" + + registry.teardown() + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Fork not available on Windows" +) +def test_sql_registry_fork_safety(): + """ + Test that SqlRegistry works correctly after a process fork. + + This test simulates what happens when Gunicorn forks worker processes. + Each worker should be able to use the registry after calling on_worker_init(). + """ + fd, registry_path = tempfile.mkstemp() + registry_config = SqlRegistryConfig( + registry_type="sql", + path=f"sqlite:///{registry_path}", + purge_feast_metadata=False, + ) + + registry = SqlRegistry(registry_config, "test_project", None) + + # Apply an entity in the parent process + entity = Entity( + name="parent_entity", + description="Created in parent process", + ) + registry.apply_entity(entity, "test_project") + + def child_process_work(result_queue): + """Work done in child process after fork.""" + try: + # This simulates what the post_fork hook does + registry.on_worker_init() + + # Try to read the entity created by parent + retrieved = registry.get_entity("parent_entity", "test_project") + assert retrieved.name == "parent_entity" + + # Try to create a new entity in the child + child_entity = Entity( + name=f"child_entity_{os.getpid()}", + description="Created in child process", + ) + registry.apply_entity(child_entity, "test_project") + + # Verify we can read it back + retrieved_child = registry.get_entity( + f"child_entity_{os.getpid()}", "test_project" + ) + assert retrieved_child.name == f"child_entity_{os.getpid()}" + + result_queue.put(("success", None)) + except Exception as e: + result_queue.put(("error", str(e))) + + # Use multiprocessing to simulate fork + result_queue = multiprocessing.Queue() + child = multiprocessing.Process(target=child_process_work, args=(result_queue,)) + child.start() + child.join(timeout=30) + + # Check result + assert not result_queue.empty(), "Child process did not return a result" + status, error = result_queue.get() + assert status == "success", f"Child process failed: {error}" + + registry.teardown() diff --git a/sdk/python/tests/unit/test_feature_server.py b/sdk/python/tests/unit/test_feature_server.py index e3fd0387fb9..aed89bbab9a 100644 --- a/sdk/python/tests/unit/test_feature_server.py +++ b/sdk/python/tests/unit/test_feature_server.py @@ -408,3 +408,77 @@ def 
load_artifacts(app: FastAPI): assert lookup_tables["sentiment_labels"]["LABEL_1"] == "neutral" assert lookup_tables["sentiment_labels"]["LABEL_2"] == "positive" assert lookup_tables["emoji_sentiment"]["😊"] == "positive" + + +# Fork-safety tests for multi-worker Gunicorn deployments +def test_feast_serve_application_has_post_fork_hook(): + """Test that FeastServeApplication properly configures the post_fork hook.""" + import sys + + if sys.platform == "win32": + pytest.skip("Gunicorn not available on Windows") + + from unittest.mock import MagicMock, patch + + # Import after platform check since gunicorn doesn't work on Windows + from feast.feature_server import FeastServeApplication + + mock_store = MagicMock() + mock_store.registry = MagicMock() + mock_store.config.auth_config.type = "no_auth" + + with patch("feast.feature_server.get_app") as mock_get_app: + mock_get_app.return_value = MagicMock() + + app = FeastServeApplication( + store=mock_store, + registry_ttl_sec=60, + bind="127.0.0.1:8000", + workers=4, + ) + + # Verify the config has been loaded + app.load_config() + + # Check that post_fork hook is set + assert app.cfg.post_fork is not None + assert callable(app.cfg.post_fork) + + +def test_feast_serve_application_post_fork_calls_on_worker_init(): + """Test that the post_fork hook calls registry.on_worker_init().""" + import sys + + if sys.platform == "win32": + pytest.skip("Gunicorn not available on Windows") + + from unittest.mock import MagicMock, patch + + from feast.feature_server import FeastServeApplication + + mock_store = MagicMock() + mock_store.registry = MagicMock() + mock_store.registry.on_worker_init = MagicMock() + mock_store.config.auth_config.type = "no_auth" + + with patch("feast.feature_server.get_app") as mock_get_app: + mock_get_app.return_value = MagicMock() + + app = FeastServeApplication( + store=mock_store, + registry_ttl_sec=60, + bind="127.0.0.1:8000", + workers=4, + ) + + app.load_config() + + # Simulate the post_fork hook being called + mock_server = MagicMock() + mock_worker = MagicMock() + mock_worker.pid = 12345 + + app._post_fork_hook(mock_server, mock_worker) + + # Verify on_worker_init was called + mock_store.registry.on_worker_init.assert_called_once()
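
Taken together, these changes make multi-worker deployments safe with a SQL registry: gunicorn forks workers, each forked worker triggers the post_fork hook, the hook calls registry.on_worker_init(), SqlRegistry rebuilds its SQLAlchemy engines, and CachingRegistry restarts its TTL refresh thread. A minimal sketch of the resulting call chain, assuming a repo whose feature_store.yaml configures a SQL registry (the repo path and bind address are placeholders, not part of this diff):

    from feast import FeatureStore
    from feast.feature_server import FeastServeApplication

    store = FeatureStore(repo_path=".")  # assumed repo with a sql registry

    # BaseApplication.run() loads the config (which installs the post_fork
    # hook via cfg.set) and forks workers; each worker then runs:
    #   _post_fork_hook -> store.registry.on_worker_init()
    #     -> SqlRegistry.reinitialize_engines()  (fresh connection pools)
    #     -> CachingRegistry.on_worker_init()    (restart refresh thread)
    FeastServeApplication(
        store=store,
        registry_ttl_sec=60,
        bind="127.0.0.1:6566",
        workers=4,
    ).run()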