From 214cc3100add5a269b111fff390871cfd68d6d37 Mon Sep 17 00:00:00 2001 From: Bhargav Dodla Date: Mon, 19 Aug 2024 17:24:31 -0700 Subject: [PATCH 1/3] fix: Optimize SQL Registry proto() method Signed-off-by: Bhargav Dodla --- protos/feast/core/Registry.proto | 1 + sdk/python/feast/feature_view.py | 2 +- .../feast/infra/registry/caching_registry.py | 24 ++- sdk/python/feast/infra/registry/sql.py | 161 ++++++++++++++---- sdk/python/feast/project_metadata.py | 24 ++- 5 files changed, 168 insertions(+), 44 deletions(-) diff --git a/protos/feast/core/Registry.proto b/protos/feast/core/Registry.proto index 0c3f8a53f94..b9352a68f36 100644 --- a/protos/feast/core/Registry.proto +++ b/protos/feast/core/Registry.proto @@ -56,4 +56,5 @@ message Registry { message ProjectMetadata { string project = 1; string project_uuid = 2; + google.protobuf.Timestamp last_updated_timestamp = 3; } diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 1a85a4b90c0..dd01078e206 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -423,7 +423,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto): if len(feature_view.entities) != len(feature_view.entity_columns): warnings.warn( - f"There are some mismatches in your feature view's registered entities. Please check if you have applied your entities correctly." + f"There are some mismatches in your feature view: {feature_view.name} registered entities. Please check if you have applied your entities correctly." f"Entities: {feature_view.entities} vs Entity Columns: {feature_view.entity_columns}" ) diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index 298639028d5..d2d39b97bb9 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -15,6 +15,7 @@ from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.saved_dataset import SavedDataset, ValidationReference from feast.stream_feature_view import StreamFeatureView from feast.utils import _utc_now @@ -24,14 +25,14 @@ class CachingRegistry(BaseRegistry): def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str): + self.cache_mode = cache_mode + self.cached_registry_proto = RegistryProto() self.cached_registry_proto = self.proto() - proto_registry_utils.init_project_metadata(self.cached_registry_proto, project) self.cached_registry_proto_created = _utc_now() self._refresh_lock = Lock() self.cached_registry_proto_ttl = timedelta( seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0 ) - self.cache_mode = cache_mode if cache_mode == "thread": self._start_thread_async_refresh(cache_ttl_seconds) atexit.register(self._exit_handler) @@ -304,6 +305,25 @@ def list_project_metadata( ) return self._list_project_metadata(project) + @abstractmethod + def _get_project_metadata(self, project: str) -> Optional[ProjectMetadata]: + pass + + # TODO: get_project_metadata() needs to be added to BaseRegistry class + def get_project_metadata( + self, project: str, allow_cache: bool = False + ) -> Optional[ProjectMetadata]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + project_metadata_proto = proto_registry_utils.get_project_metadata( + self.cached_registry_proto, project + ) + if project_metadata_proto is None: + return None + else: + return ProjectMetadata.from_proto(project_metadata_proto) + return self._get_project_metadata(project) + @abstractmethod def _get_infra(self, project: str) -> Infra: pass diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index a2b16a3a091..5f887d194e1 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -1,14 +1,16 @@ import logging import uuid +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Union from pydantic import StrictStr from sqlalchemy import ( # type: ignore BigInteger, Column, + Index, LargeBinary, MetaData, String, @@ -73,6 +75,8 @@ Column("entity_proto", LargeBinary, nullable=False), ) +Index("idx_entities_project_id", entities.c.project_id) + data_sources = Table( "data_sources", metadata, @@ -82,6 +86,8 @@ Column("data_source_proto", LargeBinary, nullable=False), ) +Index("idx_data_sources_project_id", data_sources.c.project_id) + feature_views = Table( "feature_views", metadata, @@ -93,6 +99,8 @@ Column("user_metadata", LargeBinary, nullable=True), ) +Index("idx_feature_views_project_id", feature_views.c.project_id) + stream_feature_views = Table( "stream_feature_views", metadata, @@ -103,6 +111,8 @@ Column("user_metadata", LargeBinary, nullable=True), ) +Index("idx_stream_feature_views_project_id", stream_feature_views.c.project_id) + on_demand_feature_views = Table( "on_demand_feature_views", metadata, @@ -113,6 +123,8 @@ Column("user_metadata", LargeBinary, nullable=True), ) +Index("idx_on_demand_feature_views_project_id", on_demand_feature_views.c.project_id) + feature_services = Table( "feature_services", metadata, @@ -122,6 +134,8 @@ Column("feature_service_proto", LargeBinary, nullable=False), ) +Index("idx_feature_services_project_id", feature_services.c.project_id) + saved_datasets = Table( "saved_datasets", metadata, @@ -131,6 +145,8 @@ Column("saved_dataset_proto", LargeBinary, nullable=False), ) +Index("idx_saved_datasets_project_id", saved_datasets.c.project_id) + validation_references = Table( "validation_references", metadata, @@ -140,6 +156,8 @@ Column("validation_reference_proto", LargeBinary, nullable=False), ) +Index("idx_validation_references_project_id", validation_references.c.project_id) + managed_infra = Table( "managed_infra", metadata, @@ -149,6 +167,8 @@ Column("infra_proto", LargeBinary, nullable=False), ) +Index("idx_managed_infra_project_id", managed_infra.c.project_id) + class FeastMetadataKeys(Enum): LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" @@ -164,6 +184,12 @@ class FeastMetadataKeys(Enum): Column("last_updated_timestamp", BigInteger, nullable=False), ) +Index( + "idx_feast_metadata_project_id_metadata_key", + feast_metadata.c.project_id, + feast_metadata.c.metadata_key, +) + logger = logging.getLogger(__name__) @@ -179,6 +205,10 @@ class SqlRegistryConfig(RegistryConfig): """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """ +# Number of workers in ThreadPoolExecutor +MAX_WORKERS = 5 + + class SqlRegistry(CachingRegistry): def __init__( self, @@ -192,6 +222,9 @@ def __init__( registry_config.path, **registry_config.sqlalchemy_config_kwargs ) metadata.create_all(self.engine) + + self._maybe_init_project_metadata(project) + super().__init__( project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds, @@ -482,7 +515,17 @@ def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: == FeastMetadataKeys.PROJECT_UUID.value ): project_metadata.project_uuid = row._mapping["metadata_value"] - break + + if ( + row._mapping["metadata_key"] + == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value + ): + project_metadata.last_updated_timestamp = ( + datetime.fromtimestamp( + int(row._mapping["metadata_value"]), tz=timezone.utc + ) + ) + # TODO(adchia): Add other project metadata in a structured way return [project_metadata] return [] @@ -654,8 +697,23 @@ def get_user_metadata( def proto(self) -> RegistryProto: r = RegistryProto() last_updated_timestamps = [] - projects = self._get_all_projects() - for project in projects: + + def process_project(project_metadata: ProjectMetadata): + nonlocal r, last_updated_timestamps + project = project_metadata.project_name + last_updated_timestamp = project_metadata.last_updated_timestamp + + cached_project_metadata = self.get_project_metadata(project, True) + allow_cache = False + + if cached_project_metadata is not None: + allow_cache = ( + last_updated_timestamp + <= cached_project_metadata.last_updated_timestamp + ) + + r.project_metadata.extend([project_metadata.to_proto()]) + last_updated_timestamps.append(last_updated_timestamp) for lister, registry_proto_field in [ (self.list_entities, r.entities), (self.list_feature_views, r.feature_views), @@ -665,9 +723,8 @@ def proto(self) -> RegistryProto: (self.list_feature_services, r.feature_services), (self.list_saved_datasets, r.saved_datasets), (self.list_validation_references, r.validation_references), - (self.list_project_metadata, r.project_metadata), ]: - objs: List[Any] = lister(project) # type: ignore + objs: List[Any] = lister(project, allow_cache) # type: ignore if objs: obj_protos = [obj.to_proto() for obj in objs] for obj_proto in obj_protos: @@ -680,7 +737,13 @@ def proto(self) -> RegistryProto: # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, # the registry proto only has a single infra field, which we're currently setting as the "last" project. r.infra.CopyFrom(self.get_infra(project).to_proto()) - last_updated_timestamps.append(self._get_last_updated_metadata(project)) + + project_metadata_list = self.get_all_projects() + + with ThreadPoolExecutor( + max_workers=MAX_WORKERS + ) as executor: # Adjust max_workers as needed. Defaults to 5 + executor.map(process_project, project_metadata_list) if last_updated_timestamps: r.last_updated.FromDatetime(max(last_updated_timestamps)) @@ -700,8 +763,6 @@ def _apply_object( proto_field_name: str, name: Optional[str] = None, ): - self._maybe_init_project_metadata(project) - name = name or (obj.name if hasattr(obj, "name") else None) assert name, f"name needs to be provided for {obj}" @@ -821,8 +882,6 @@ def _get_object( proto_field_name: str, not_found_exception: Optional[Callable], ): - self._maybe_init_project_metadata(project) - with self.engine.begin() as conn: stmt = select(table).where( getattr(table.c, id_field_name) == name, table.c.project_id == project @@ -845,7 +904,6 @@ def _list_objects( proto_field_name: str, tags: Optional[dict[str, str]] = None, ): - self._maybe_init_project_metadata(project) with self.engine.begin() as conn: stmt = select(table).where(table.c.project_id == project) rows = conn.execute(stmt).all() @@ -894,33 +952,64 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str): ) conn.execute(insert_stmt) - def _get_last_updated_metadata(self, project: str): + def get_all_projects(self) -> List[ProjectMetadata]: + """ + Returns all projects with metadata + """ + project_metadata_dict: Dict[str, ProjectMetadata] = {} with self.engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key - == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - feast_metadata.c.project_id == project, - ) - row = conn.execute(stmt).first() - if not row: - return None - update_time = int(row._mapping["last_updated_timestamp"]) + stmt = select(feast_metadata) + rows = conn.execute(stmt).all() + if rows: + for row in rows: + project_id = row._mapping["project_id"] + metadata_key = row._mapping["metadata_key"] + metadata_value = row._mapping["metadata_value"] + + if project_id not in project_metadata_dict: + project_metadata_dict[project_id] = ProjectMetadata( + project_name=project_id + ) + + project_metadata_model: ProjectMetadata = project_metadata_dict[ + project_id + ] + if metadata_key == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata_model.project_uuid = metadata_value - return datetime.fromtimestamp(update_time, tz=timezone.utc) + if metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value: + project_metadata_model.last_updated_timestamp = ( + datetime.fromtimestamp(int(metadata_value), tz=timezone.utc) + ) + return list(project_metadata_dict.values()) - def _get_all_projects(self) -> Set[str]: - projects = set() + def _get_project_metadata( + self, + project: str, + ) -> Optional[ProjectMetadata]: + """ + Returns given project metadata. + """ with self.engine.begin() as conn: - for table in { - entities, - data_sources, - feature_views, - on_demand_feature_views, - stream_feature_views, - }: - stmt = select(table) - rows = conn.execute(stmt).all() + stmt = select(feast_metadata).where( + feast_metadata.c.project_id == project, + ) + rows = conn.execute(stmt).all() + if rows: + project_metadata: ProjectMetadata = ProjectMetadata( + project_name=project + ) for row in rows: - projects.add(row._mapping["project_id"]) + metadata_key = row._mapping["metadata_key"] + metadata_value = row._mapping["metadata_value"] + + if metadata_key == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata.project_uuid = metadata_value - return projects + if metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value: + project_metadata.last_updated_timestamp = ( + datetime.fromtimestamp(int(metadata_value), tz=timezone.utc) + ) + return project_metadata + else: + return None diff --git a/sdk/python/feast/project_metadata.py b/sdk/python/feast/project_metadata.py index 64488a03629..10887fce3af 100644 --- a/sdk/python/feast/project_metadata.py +++ b/sdk/python/feast/project_metadata.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import uuid +from datetime import datetime, timezone from typing import Optional from google.protobuf.json_format import MessageToJson @@ -28,16 +29,19 @@ class ProjectMetadata: Attributes: project_name: The registry-scoped unique name of the project. project_uuid: The UUID for this project + last_updated_timestamp: Last updated timestamp for this project """ project_name: str project_uuid: str + last_updated_timestamp: datetime def __init__( self, *args, project_name: Optional[str] = None, project_uuid: Optional[str] = None, + last_updated_timestamp: datetime = datetime.fromtimestamp(1, tz=timezone.utc), ): """ Creates an Project metadata object. @@ -54,9 +58,10 @@ def __init__( self.project_name = project_name self.project_uuid = project_uuid or f"{uuid.uuid4()}" + self.last_updated_timestamp = last_updated_timestamp def __hash__(self) -> int: - return hash((self.project_name, self.project_uuid)) + return hash((self.project_name, self.project_uuid, self.last_updated_timestamp)) def __eq__(self, other): if not isinstance(other, ProjectMetadata): @@ -67,6 +72,7 @@ def __eq__(self, other): if ( self.project_name != other.project_name or self.project_uuid != other.project_uuid + or self.last_updated_timestamp != other.last_updated_timestamp ): return False @@ -89,12 +95,15 @@ def from_proto(cls, project_metadata_proto: ProjectMetadataProto): Returns: A ProjectMetadata object based on the protobuf. """ - entity = cls( + project_metadata = cls( project_name=project_metadata_proto.project, project_uuid=project_metadata_proto.project_uuid, + last_updated_timestamp=project_metadata_proto.last_updated_timestamp.ToDatetime( + tzinfo=timezone.utc + ), ) - return entity + return project_metadata def to_proto(self) -> ProjectMetadataProto: """ @@ -104,6 +113,11 @@ def to_proto(self) -> ProjectMetadataProto: An ProjectMetadataProto protobuf. """ - return ProjectMetadataProto( - project=self.project_name, project_uuid=self.project_uuid + project_metadata_proto = ProjectMetadataProto( + project=self.project_name, + project_uuid=self.project_uuid, ) + project_metadata_proto.last_updated_timestamp.FromDatetime( + self.last_updated_timestamp + ) + return project_metadata_proto From 3a0dee8e9de98c44572872af5f352a1fa038c976 Mon Sep 17 00:00:00 2001 From: Bhargav Dodla Date: Mon, 19 Aug 2024 18:33:18 -0700 Subject: [PATCH 2/3] fix: Fixed linting issue Signed-off-by: Bhargav Dodla --- sdk/python/feast/project_metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/python/feast/project_metadata.py b/sdk/python/feast/project_metadata.py index 10887fce3af..b6ccc8c64b6 100644 --- a/sdk/python/feast/project_metadata.py +++ b/sdk/python/feast/project_metadata.py @@ -98,8 +98,8 @@ def from_proto(cls, project_metadata_proto: ProjectMetadataProto): project_metadata = cls( project_name=project_metadata_proto.project, project_uuid=project_metadata_proto.project_uuid, - last_updated_timestamp=project_metadata_proto.last_updated_timestamp.ToDatetime( - tzinfo=timezone.utc + last_updated_timestamp=project_metadata_proto.last_updated_timestamp.ToDatetime().astimezone( + tz=timezone.utc ), ) From 8de470e5d410c4b98ab1958c92b325e32b6956df Mon Sep 17 00:00:00 2001 From: Bhargav Dodla Date: Tue, 20 Aug 2024 16:15:18 -0700 Subject: [PATCH 3/3] fix: Added tests to the optimization changes Signed-off-by: Bhargav Dodla --- .../feast/infra/registry/caching_registry.py | 44 +++---- sdk/python/feast/infra/registry/sql.py | 36 +++-- sdk/python/feast/project_metadata.py | 4 +- sdk/python/feast/repo_config.py | 3 + .../registration/test_universal_registry.py | 124 +++++++++++++++++- .../tests/unit/test_on_demand_feature_view.py | 9 +- .../tests/unit/test_project_metadata.py | 99 ++++++++++++++ 7 files changed, 275 insertions(+), 44 deletions(-) create mode 100644 sdk/python/tests/unit/test_project_metadata.py diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index d2d39b97bb9..40859ea2678 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -27,12 +27,12 @@ class CachingRegistry(BaseRegistry): def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str): self.cache_mode = cache_mode self.cached_registry_proto = RegistryProto() - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = _utc_now() self._refresh_lock = Lock() self.cached_registry_proto_ttl = timedelta( seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0 ) + self.cached_registry_proto = self.proto() + self.cached_registry_proto_created = _utc_now() if cache_mode == "thread": self._start_thread_async_refresh(cache_ttl_seconds) atexit.register(self._exit_handler) @@ -332,34 +332,32 @@ def get_infra(self, project: str, allow_cache: bool = False) -> Infra: return self._get_infra(project) def refresh(self, project: Optional[str] = None): - if project: - project_metadata = proto_registry_utils.get_project_metadata( - registry_proto=self.cached_registry_proto, project=project - ) - if not project_metadata: - proto_registry_utils.init_project_metadata( - self.cached_registry_proto, project - ) self.cached_registry_proto = self.proto() self.cached_registry_proto_created = _utc_now() def _refresh_cached_registry_if_necessary(self): if self.cache_mode == "sync": with self._refresh_lock: - expired = ( - self.cached_registry_proto is None - or self.cached_registry_proto_created is None - ) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - _utc_now() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl + if self.cached_registry_proto == RegistryProto(): + # Avoids the need to refresh the registry when cache is not populated yet + # Specially during the __init__ phase + # proto() will populate the cache with project metadata if no objects are registered + expired = False + else: + expired = ( + self.cached_registry_proto is None + or self.cached_registry_proto_created is None + ) or ( + self.cached_registry_proto_ttl.total_seconds() + > 0 # 0 ttl means infinity + and ( + _utc_now() + > ( + self.cached_registry_proto_created + + self.cached_registry_proto_ttl + ) ) ) - ) if expired: logger.info("Registry cache expired, so refreshing") self.refresh() @@ -371,7 +369,7 @@ def _start_thread_async_refresh(self, cache_ttl_seconds): self.registry_refresh_thread = threading.Timer( cache_ttl_seconds, self._start_thread_async_refresh, [cache_ttl_seconds] ) - self.registry_refresh_thread.setDaemon(True) + self.registry_refresh_thread.daemon = True self.registry_refresh_thread.start() def _exit_handler(self): diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index 5f887d194e1..2be472dd137 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -205,10 +205,6 @@ class SqlRegistryConfig(RegistryConfig): """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """ -# Number of workers in ThreadPoolExecutor -MAX_WORKERS = 5 - - class SqlRegistry(CachingRegistry): def __init__( self, @@ -221,6 +217,9 @@ def __init__( self.engine: Engine = create_engine( registry_config.path, **registry_config.sqlalchemy_config_kwargs ) + self.thread_pool_executor_worker_count = ( + registry_config.thread_pool_executor_worker_count + ) metadata.create_all(self.engine) self._maybe_init_project_metadata(project) @@ -368,6 +367,23 @@ def _list_entities( entities, project, EntityProto, Entity, "entity_proto", tags=tags ) + # TODO: Add to BaseRegistry + def delete_project(self, project: str): + with self.engine.begin() as conn: + for t in { + entities, + data_sources, + feature_views, + feature_services, + on_demand_feature_views, + saved_datasets, + validation_references, + managed_infra, + feast_metadata, + }: + stmt = delete(t).where(t.c.project_id == project) + conn.execute(stmt) + def delete_entity(self, name: str, project: str, commit: bool = True): return self._delete_object( entities, name, project, "entity_name", EntityNotFoundException @@ -740,10 +756,14 @@ def process_project(project_metadata: ProjectMetadata): project_metadata_list = self.get_all_projects() - with ThreadPoolExecutor( - max_workers=MAX_WORKERS - ) as executor: # Adjust max_workers as needed. Defaults to 5 - executor.map(process_project, project_metadata_list) + if self.thread_pool_executor_worker_count == 0: + for project_metadata in project_metadata_list: + process_project(project_metadata) + else: + with ThreadPoolExecutor( + max_workers=self.thread_pool_executor_worker_count + ) as executor: + executor.map(process_project, project_metadata_list) if last_updated_timestamps: r.last_updated.FromDatetime(max(last_updated_timestamps)) diff --git a/sdk/python/feast/project_metadata.py b/sdk/python/feast/project_metadata.py index b6ccc8c64b6..d8a28ded7f0 100644 --- a/sdk/python/feast/project_metadata.py +++ b/sdk/python/feast/project_metadata.py @@ -98,8 +98,8 @@ def from_proto(cls, project_metadata_proto: ProjectMetadataProto): project_metadata = cls( project_name=project_metadata_proto.project, project_uuid=project_metadata_proto.project_uuid, - last_updated_timestamp=project_metadata_proto.last_updated_timestamp.ToDatetime().astimezone( - tz=timezone.utc + last_updated_timestamp=project_metadata_proto.last_updated_timestamp.ToDatetime().replace( + tzinfo=timezone.utc ), ) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index fc2792e3237..b18ce1fdcb9 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -128,6 +128,9 @@ class RegistryConfig(FeastBaseModel): cache_mode: StrictStr = "sync" """ str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)""" + thread_pool_executor_worker_count: StrictInt = 0 + """ int: Number of worker threads to use for asynchronous caching in SQL Registry. If set to 0, it doesn't use ThreadPoolExecutor. """ + @field_validator("path") def validate_path(cls, path: str, values: ValidationInfo) -> str: if values.data.get("registry_type") == "sql": diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index 9dcd1b5b91c..29c07d1bb2d 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -14,7 +14,7 @@ import logging import os import time -from datetime import timedelta, timezone +from datetime import datetime, timedelta, timezone from tempfile import mkstemp from unittest import mock @@ -155,7 +155,7 @@ def pg_registry_async(): container.start() - registry_config = _given_registry_config_for_pg_sql(container, 2, "thread") + registry_config = _given_registry_config_for_pg_sql(container, 2, "thread", 3) yield SqlRegistry(registry_config, "project", None) @@ -163,7 +163,10 @@ def pg_registry_async(): def _given_registry_config_for_pg_sql( - container, cache_ttl_seconds=2, cache_mode="sync" + container, + cache_ttl_seconds=2, + cache_mode="sync", + thread_pool_executor_worker_count=0, ): log_string_to_wait_for = "database system is ready to accept connections" waited = wait_for_logs( @@ -180,6 +183,7 @@ def _given_registry_config_for_pg_sql( registry_type="sql", cache_ttl_seconds=cache_ttl_seconds, cache_mode=cache_mode, + thread_pool_executor_worker_count=thread_pool_executor_worker_count, # The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()` # to understand that we are using psycopg3. path=f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}", @@ -204,14 +208,19 @@ def mysql_registry_async(): container = MySqlContainer("mysql:latest") container.start() - registry_config = _given_registry_config_for_mysql(container, 2, "thread") + registry_config = _given_registry_config_for_mysql(container, 2, "thread", 3) yield SqlRegistry(registry_config, "project", None) container.stop() -def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"): +def _given_registry_config_for_mysql( + container, + cache_ttl_seconds=2, + cache_mode="sync", + thread_pool_executor_worker_count=0, +): import sqlalchemy engine = sqlalchemy.create_engine( @@ -224,11 +233,12 @@ def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode= path=container.get_connection_url(), cache_ttl_seconds=cache_ttl_seconds, cache_mode=cache_mode, + thread_pool_executor_worker_count=thread_pool_executor_worker_count, sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True}, ) -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def sqlite_registry(): registry_config = RegistryConfig( registry_type="sql", @@ -342,6 +352,7 @@ def test_apply_entity_success(test_registry): project_uuid = project_metadata[0].project_uuid assert len(project_metadata[0].project_uuid) == 36 assert_project_uuid(project, project_uuid, test_registry) + assert project_metadata[0].last_updated_timestamp is not None entities = test_registry.list_entities(project, tags=entity.tags) assert_project_uuid(project, project_uuid, test_registry) @@ -1343,3 +1354,104 @@ def validate_project_uuid(project_uuid, test_registry): assert len(test_registry.cached_registry_proto.project_metadata) == 1 project_metadata = test_registry.cached_registry_proto.project_metadata[0] assert project_metadata.project_uuid == project_uuid + + +@pytest.mark.integration +@pytest.mark.parametrize( + "test_registry", + sql_fixtures, +) +def test_project_metadata_success(test_registry): + project = "project" + project_metadata = test_registry.get_project_metadata(project) + assert project_metadata.project_name == project + assert project_metadata.last_updated_timestamp == datetime.fromtimestamp( + 1, tz=timezone.utc + ) + + last_refresh_timestamp = project_metadata.last_updated_timestamp + + entity = Entity( + name="test_project_metadata_success", + description="test_project_metadata_success", + tags={"team": "matchmaking"}, + ) + + # Register Entity + test_registry.apply_entity(entity, project) + + project_metadata = test_registry.get_project_metadata(project) + assert project_metadata.project_name == project + assert project_metadata.last_updated_timestamp > last_refresh_timestamp + + project_metadata_list = test_registry.get_all_projects() + assert len(project_metadata_list) == 1 + + test_registry.delete_project(project) + + project_metadata = test_registry.get_project_metadata(project) + assert project_metadata is None + + project_metadata_list = test_registry.get_all_projects() + assert len(project_metadata_list) == 0 + + test_registry.teardown() + + +@pytest.mark.integration +@pytest.mark.parametrize( + "test_registry", + sql_fixtures, +) +def test_project_metadata_from_cache_on_init_success(test_registry): + # In Setup phase, proto() method is not executing fully due to lazy fixtures, so forcing the call + test_registry.cached_registry_proto = test_registry.proto() + project = "project" + project_metadata = test_registry.get_project_metadata(project, allow_cache=True) + assert project_metadata.project_name == project + assert project_metadata.last_updated_timestamp == datetime.fromtimestamp( + 1, tz=timezone.utc + ) + last_refresh_timestamp = project_metadata.last_updated_timestamp + + entity = Entity( + name="test_project_metadata_from_cache_on_init_success", + description="test_project_metadata_from_cache_on_init_success", + tags={"team": "matchmaking"}, + ) + # Register Entity + test_registry.apply_entity(entity, project) + + project_metadata = test_registry.get_project_metadata(project) + assert project_metadata.project_name == project + assert project_metadata.last_updated_timestamp > last_refresh_timestamp + + test_registry.refresh() + project_metadata = test_registry.get_project_metadata(project, allow_cache=True) + assert project_metadata.project_name == project + assert project_metadata.last_updated_timestamp > last_refresh_timestamp + + project_metadata_list = test_registry.get_all_projects() + assert len(project_metadata_list) == 1 + + test_registry.teardown() + + +@pytest.mark.integration +@pytest.mark.parametrize( + "test_registry", + async_sql_fixtures, +) +def test_registry_cache_project_metadata_thread_async(test_registry): + project = "project" + # Wait for cache to be refreshed + time.sleep(4) + # Now objects exist + project_metadata = test_registry.get_project_metadata(project, allow_cache=True) + assert project_metadata is not None + assert project_metadata.project_name == project + + project_metadata_list = test_registry.get_all_projects() + assert len(project_metadata_list) == 1 + + test_registry.teardown() diff --git a/sdk/python/tests/unit/test_on_demand_feature_view.py b/sdk/python/tests/unit/test_on_demand_feature_view.py index d9cc5dee50d..7717a184386 100644 --- a/sdk/python/tests/unit/test_on_demand_feature_view.py +++ b/sdk/python/tests/unit/test_on_demand_feature_view.py @@ -251,11 +251,9 @@ def test_from_proto_backwards_compatible_udf(): proto.spec.feature_transformation.user_defined_function.body_text ) - # And now we're going to null the feature_transformation proto object before reserializing the entire proto - # proto.spec.user_defined_function.body_text = on_demand_feature_view.transformation.udf_string - proto.spec.feature_transformation.user_defined_function.name = "" - proto.spec.feature_transformation.user_defined_function.body = b"" - proto.spec.feature_transformation.user_defined_function.body_text = "" + # For objects that are already registered, feature_transformation and mode is not set + proto.spec.feature_transformation.Clear() + proto.spec.ClearField("mode") # And now we expect the to get the same object back under feature_transformation reserialized_proto = OnDemandFeatureView.from_proto(proto) @@ -263,3 +261,4 @@ def test_from_proto_backwards_compatible_udf(): reserialized_proto.feature_transformation.udf_string == on_demand_feature_view.feature_transformation.udf_string ) + assert reserialized_proto.mode == "pandas" diff --git a/sdk/python/tests/unit/test_project_metadata.py b/sdk/python/tests/unit/test_project_metadata.py new file mode 100644 index 00000000000..ec000623589 --- /dev/null +++ b/sdk/python/tests/unit/test_project_metadata.py @@ -0,0 +1,99 @@ +import unittest +from datetime import datetime, timezone + +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Registry_pb2 import ProjectMetadata as ProjectMetadataProto + + +class TestProjectMetadata(unittest.TestCase): + def setUp(self): + self.project_name = "test_project" + self.project_uuid = "123e4567-e89b-12d3-a456-426614174000" + self.timestamp = datetime(2021, 1, 1, tzinfo=timezone.utc) + + def test_initialization(self): + metadata = ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid, + last_updated_timestamp=self.timestamp, + ) + self.assertEqual(metadata.project_name, self.project_name) + self.assertEqual(metadata.project_uuid, self.project_uuid) + self.assertEqual(metadata.last_updated_timestamp, self.timestamp) + + def test_initialization_with_default_last_updated_timestamp(self): + metadata = ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid, + ) + self.assertEqual(metadata.project_name, self.project_name) + self.assertEqual(metadata.project_uuid, self.project_uuid) + self.assertEqual( + metadata.last_updated_timestamp, datetime.fromtimestamp(1, tz=timezone.utc) + ) + + def test_initialization_without_project_name(self): + with self.assertRaises(ValueError): + ProjectMetadata() + + def test_equality(self): + metadata1 = ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid, + last_updated_timestamp=self.timestamp, + ) + metadata2 = ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid, + last_updated_timestamp=self.timestamp, + ) + self.assertEqual(metadata1, metadata2) + + def test_hash(self): + metadata = ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid, + last_updated_timestamp=self.timestamp, + ) + self.assertEqual( + hash(metadata), hash((self.project_name, self.project_uuid, self.timestamp)) + ) + + def test_from_proto(self): + proto = ProjectMetadataProto( + project=self.project_name, + project_uuid=self.project_uuid, + ) + proto.last_updated_timestamp.FromDatetime(self.timestamp) + metadata = ProjectMetadata.from_proto(proto) + self.assertEqual(metadata.project_name, self.project_name) + self.assertEqual(metadata.project_uuid, self.project_uuid) + self.assertEqual(metadata.last_updated_timestamp, self.timestamp) + + def test_to_proto(self): + metadata = ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid, + last_updated_timestamp=self.timestamp, + ) + proto = metadata.to_proto() + self.assertEqual(proto.project, self.project_name) + self.assertEqual(proto.project_uuid, self.project_uuid) + self.assertEqual( + proto.last_updated_timestamp.ToDatetime().replace(tzinfo=timezone.utc), + self.timestamp, + ) + + def test_conversion_to_proto_and_back(self): + metadata = ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid, + last_updated_timestamp=self.timestamp, + ) + proto = metadata.to_proto() + metadata_from_proto = ProjectMetadata.from_proto(proto) + self.assertEqual(metadata, metadata_from_proto) + + +if __name__ == "__main__": + unittest.main()