diff --git a/sdk/python/feast/api/registry/rest/metrics.py b/sdk/python/feast/api/registry/rest/metrics.py index 095feec51cb..32253e85aca 100644 --- a/sdk/python/feast/api/registry/rest/metrics.py +++ b/sdk/python/feast/api/registry/rest/metrics.py @@ -1,5 +1,4 @@ import json -import logging from typing import Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Request @@ -57,7 +56,6 @@ class PopularTagsResponse(BaseModel): def get_metrics_router(grpc_handler, server=None) -> APIRouter: - logger = logging.getLogger(__name__) router = APIRouter() @router.get("/metrics/resource_counts", tags=["Metrics"]) @@ -321,20 +319,43 @@ async def recently_visited( user = getattr(request.state, "user", None) if not user: user = "anonymous" - project_val = project or (server.store.project if server else None) key = f"recently_visited_{user}" - logger.info( - f"[/metrics/recently_visited] Project: {project_val}, Key: {key}, Object: {object_type}" - ) - try: - visits_json = ( - server.registry.get_project_metadata(project_val, key) - if server - else None - ) - visits = json.loads(visits_json) if visits_json else [] - except Exception: - visits = [] + visits = [] + if project: + try: + visits_json = ( + server.registry.get_project_metadata(project, key) + if server + else None + ) + visits = json.loads(visits_json) if visits_json else [] + except Exception: + visits = [] + else: + try: + if server: + projects_resp = grpc_call( + grpc_handler.ListProjects, + RegistryServer_pb2.ListProjectsRequest(allow_cache=True), + ) + all_projects = [ + p["spec"]["name"] for p in projects_resp.get("projects", []) + ] + for project_name in all_projects: + try: + visits_json = server.registry.get_project_metadata( + project_name, key + ) + if visits_json: + project_visits = json.loads(visits_json) + visits.extend(project_visits) + except Exception: + continue + visits = sorted( + visits, key=lambda x: x.get("timestamp", ""), reverse=True + ) + except Exception: + visits = [] if object_type: visits = [v for v in visits if v.get("object") == object_type] diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index a03cfdb12d1..4e429f8e075 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -21,9 +21,6 @@ def get_or_create_new_spark_session( conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()]) ) - spark_builder = spark_builder.config("spark.driver.host", "127.0.0.1") - spark_builder = spark_builder.config("spark.driver.bindAddress", "127.0.0.1") - spark_session = spark_builder.getOrCreate() spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") return spark_session diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index ab7944585b6..ce346272af9 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -425,9 +425,6 @@ def list_projects( return self._list_projects(tags) def refresh(self, project: Optional[str] = None): - if self._refresh_lock.locked(): - logger.debug("Skipping refresh if already in progress") - return try: self.cached_registry_proto = self.proto() self.cached_registry_proto_created = _utc_now() @@ -436,43 +433,47 @@ def refresh(self, project: Optional[str] = None): def _refresh_cached_registry_if_necessary(self): if self.cache_mode == "sync": - # Try acquiring the lock without blocking - if not self._refresh_lock.acquire(blocking=False): - logger.debug( - "Skipping refresh if lock is already held by another thread" - ) - return - try: - if self.cached_registry_proto == RegistryProto(): - # Avoids the need to refresh the registry when cache is not populated yet - # Specially during the __init__ phase - # proto() will populate the cache with project metadata if no objects are registered - expired = False - else: - expired = ( - self.cached_registry_proto is None - or self.cached_registry_proto_created is None - ) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - _utc_now() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl - ) - ) + + def is_cache_expired(): + if ( + self.cached_registry_proto is None + or self.cached_registry_proto == RegistryProto() + ): + return True + + # Cache is expired if creation time is None + if ( + not hasattr(self, "cached_registry_proto_created") + or self.cached_registry_proto_created is None + ): + return True + + # Cache is expired if TTL > 0 and current time exceeds creation + TTL + if self.cached_registry_proto_ttl.total_seconds() > 0 and _utc_now() > ( + self.cached_registry_proto_created + self.cached_registry_proto_ttl + ): + return True + + return False + + if is_cache_expired(): + if not self._refresh_lock.acquire(blocking=False): + logger.debug( + "Skipping refresh if lock is already held by another thread" + ) + return + try: + logger.info( + f"Registry cache expired(ttl: {self.cached_registry_proto_ttl.total_seconds()} seconds), so refreshing" ) - if expired: - logger.debug("Registry cache expired, so refreshing") self.refresh() - except Exception as e: - logger.debug( - f"Error in _refresh_cached_registry_if_necessary: {e}", - exc_info=True, - ) - finally: - self._refresh_lock.release() # Always release the lock safely + except Exception as e: + logger.debug( + f"Error in _refresh_cached_registry_if_necessary: {e}", + exc_info=True, + ) + finally: + self._refresh_lock.release() def _start_thread_async_refresh(self, cache_ttl_seconds): self.refresh() diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index c2e20005a65..360a844b0be 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -15,6 +15,7 @@ MetaData, String, Table, + Text, create_engine, delete, insert, @@ -209,7 +210,7 @@ class FeastMetadataKeys(Enum): metadata, Column("project_id", String(255), primary_key=True), Column("metadata_key", String(50), primary_key=True), - Column("metadata_value", String(50), nullable=False), + Column("metadata_value", Text, nullable=False), Column("last_updated_timestamp", BigInteger, nullable=False), ) @@ -326,6 +327,7 @@ def teardown(self): entities, data_sources, feature_views, + stream_feature_views, feature_services, on_demand_feature_views, saved_datasets, @@ -845,18 +847,6 @@ def process_project(project: Project): project_name = project.name last_updated_timestamp = project.last_updated_timestamp - try: - cached_project = self.get_project(project_name, True) - except ProjectObjectNotFoundException: - cached_project = None - - allow_cache = False - - if cached_project is not None: - allow_cache = ( - last_updated_timestamp <= cached_project.last_updated_timestamp - ) - r.projects.extend([project.to_proto()]) last_updated_timestamps.append(last_updated_timestamp) @@ -871,7 +861,7 @@ def process_project(project: Project): (self.list_validation_references, r.validation_references), (self.list_permissions, r.permissions), ]: - objs: List[Any] = lister(project_name, allow_cache) # type: ignore + objs: List[Any] = lister(project_name, allow_cache=False) # type: ignore if objs: obj_protos = [obj.to_proto() for obj in objs] for obj_proto in obj_protos: @@ -1020,6 +1010,9 @@ def _apply_object( if not self.purge_feast_metadata: self._set_last_updated_metadata(update_datetime, project, conn) + if self.cache_mode == "sync": + self.refresh() + def _maybe_init_project_metadata(self, project): # Initialize project metadata if needed with self.write_engine.begin() as conn: @@ -1062,6 +1055,8 @@ def _delete_object( if not self.purge_feast_metadata: self._set_last_updated_metadata(_utc_now(), project, conn) + if self.cache_mode == "sync": + self.refresh() return rows.rowcount def _get_object( diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index ea0f3ddcd67..eb663d8565a 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -273,6 +273,8 @@ def sqlite_registry(): registry_config = SqlRegistryConfig( registry_type="sql", path="sqlite://", + cache_ttl_seconds=2, + cache_mode="sync", ) yield SqlRegistry(registry_config, "project", None) @@ -1156,11 +1158,10 @@ def test_registry_cache(test_registry): registry_data_sources_cached = test_registry.list_data_sources( project, allow_cache=True ) - # Not refreshed cache, so cache miss - assert len(registry_feature_views_cached) == 0 - assert len(registry_data_sources_cached) == 0 + assert len(registry_feature_views_cached) == 1 + assert len(registry_data_sources_cached) == 1 + test_registry.refresh(project) - # Now objects exist registry_feature_views_cached = test_registry.list_feature_views( project, allow_cache=True, tags=fv1.tags ) diff --git a/sdk/python/tests/unit/infra/registry/test_registry.py b/sdk/python/tests/unit/infra/registry/test_registry.py index f103925f1c1..51bb7d28e32 100644 --- a/sdk/python/tests/unit/infra/registry/test_registry.py +++ b/sdk/python/tests/unit/infra/registry/test_registry.py @@ -4,6 +4,7 @@ import pytest from feast.infra.registry.caching_registry import CachingRegistry +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto class TestCachingRegistry(CachingRegistry): @@ -188,6 +189,63 @@ def test_cache_expiry_triggers_refresh(registry): mock_refresh.assert_called_once() +def test_empty_cache_refresh_with_ttl(registry): + """Test that empty cache is refreshed when TTL > 0""" + # Set up empty cache with TTL > 0 + registry.cached_registry_proto = RegistryProto() + registry.cached_registry_proto_created = datetime.now(timezone.utc) + registry.cached_registry_proto_ttl = timedelta(seconds=10) # TTL > 0 + + # Mock refresh to check if it's called + with patch.object( + CachingRegistry, "refresh", wraps=registry.refresh + ) as mock_refresh: + registry._refresh_cached_registry_if_necessary() + # Should refresh because cache is empty and TTL > 0 + mock_refresh.assert_called_once() + + +def test_concurrent_cache_refresh_race_condition(registry): + """Test that concurrent requests don't skip cache refresh when cache is expired""" + import threading + import time + + # Set up expired cache + registry.cached_registry_proto = RegistryProto() + registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta( + seconds=5 + ) + registry.cached_registry_proto_ttl = timedelta( + seconds=2 + ) # TTL = 2 seconds, cache is expired + + refresh_calls = [] + + def mock_refresh(): + refresh_calls.append(threading.current_thread().ident) + time.sleep(0.1) # Simulate refresh work + + # Mock the refresh method to track calls + with patch.object(registry, "refresh", side_effect=mock_refresh): + # Simulate concurrent requests + threads = [] + for i in range(3): + thread = threading.Thread( + target=registry._refresh_cached_registry_if_necessary + ) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # At least one thread should have called refresh (the first one to acquire the lock) + assert len(refresh_calls) >= 1, ( + "At least one thread should have refreshed the cache" + ) + + def test_skip_refresh_if_lock_held(registry): """Test that refresh is skipped if the lock is already held by another thread""" registry.cached_registry_proto = "some_cached_data"