From 7c951eb570d3391f662eb638a219490011178643 Mon Sep 17 00:00:00 2001 From: jyejare Date: Thu, 12 Jun 2025 17:14:54 +0530 Subject: [PATCH 1/6] Remotely retrieve the docs Signed-off-by: jyejare --- sdk/python/feast/feature_server.py | 13 +- .../feast/infra/online_stores/remote.py | 197 ++++++++++++++++++ 2 files changed, 207 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index a095255d5af..febf0fb6e2e 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -88,6 +88,13 @@ class GetOnlineFeaturesRequest(BaseModel): feature_service: Optional[str] = None features: Optional[List[str]] = None full_feature_names: bool = False + + +class GetOnlineDocumentsRequest(BaseModel): + feature_service: Optional[str] = None + features: Optional[List[str]] = None + full_feature_names: bool = False + top_k: Optional[int] = None query_embedding: Optional[List[float]] = None query_string: Optional[str] = None @@ -110,7 +117,7 @@ class SaveDocumentRequest(BaseModel): data: dict -def _get_features(request: GetOnlineFeaturesRequest, store: "feast.FeatureStore"): +def _get_features(request: GetOnlineFeaturesRequest|GetOnlineDocumentsRequest, store: "feast.FeatureStore"): if request.feature_service: feature_service = store.get_feature_service( request.feature_service, allow_cache=True @@ -246,7 +253,7 @@ async def get_online_features(request: GetOnlineFeaturesRequest) -> Dict[str, An dependencies=[Depends(inject_user_details)], ) async def retrieve_online_documents( - request: GetOnlineFeaturesRequest, + request: GetOnlineDocumentsRequest, ) -> Dict[str, Any]: logger.warning( "This endpoint is in alpha and will be moved to /get-online-features when stable." @@ -256,9 +263,9 @@ async def retrieve_online_documents( read_params = dict( features=features, - full_feature_names=request.full_feature_names, query=request.query_embedding, query_string=request.query_string, + top_k=request.top_k, ) response = await run_in_threadpool( diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index ea09362299d..0cbbb7632ba 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -170,6 +170,123 @@ def online_read( logger.error(error_msg) raise RuntimeError(error_msg) + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + embedding: List[float], + top_k: int, + requested_features: Optional[List[str]] = None, + distance_metric: Optional[str] = "L2", + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + assert isinstance(config.online_store, RemoteOnlineStoreConfig) + config.online_store.__class__ = RemoteOnlineStoreConfig + + req_body = self._construct_online_documents_api_json_request( + table, requested_features, embedding, top_k, distance_metric + ) + response = get_remote_online_documents(config=config, req_body=req_body) + if response.status_code == 200: + logger.debug("Able to retrieve the online documents from feature server.") + response_json = json.loads(response.text) + event_ts = self._get_event_ts(response_json) + result_tuples: List[ + Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] + ] = [] + for feature_value_index in range( + len(response_json["results"][0]["values"]) + ): + feature_values_dict: Dict[str, ValueProto] = dict() + for index, feature_name in enumerate( + response_json["metadata"]["feature_names"] + ): + if ( + requested_features is not None + and feature_name in requested_features + ): + if ( + response_json["results"][index]["statuses"][ + feature_value_index + ] + == "PRESENT" + ): + message = python_values_to_proto_values( + [ + response_json["results"][index]["values"][ + feature_value_index + ] + ], + ValueType.UNKNOWN, + ) + feature_values_dict[feature_name] = message[0] + else: + feature_values_dict[feature_name] = ValueProto() + result_tuples.append((event_ts, feature_values_dict)) + return result_tuples + else: + error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={response.status_code}, error_message={response.text}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + embedding: Optional[List[float]], + top_k: int, + requested_features: Optional[List[str]] = None, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + assert isinstance(config.online_store, RemoteOnlineStoreConfig) + config.online_store.__class__ = RemoteOnlineStoreConfig + + req_body = self._construct_online_documents_v2_api_json_request( + table, requested_features, embedding, top_k, distance_metric, query_string + ) + response = get_remote_online_documents_v2(config=config, req_body=req_body) + if response.status_code == 200: + logger.debug("Able to retrieve the online documents from feature server.") + response_json = json.loads(response.text) + event_ts = self._get_event_ts(response_json) + result_tuples: List[ + Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] + ] = [] + for feature_value_index in range( + len(response_json["results"][0]["values"]) + ): + feature_values_dict: Dict[str, ValueProto] = dict() + for index, feature_name in enumerate( + response_json["metadata"]["feature_names"] + ): + if ( + requested_features is not None + and feature_name in requested_features + ): + if ( + response_json["results"][index]["statuses"][ + feature_value_index + ] + == "PRESENT" + ): + message = python_values_to_proto_values( + [ + response_json["results"][index]["values"][ + feature_value_index + ] + ], + ValueType.UNKNOWN, + ) + feature_values_dict[feature_name] = message[0] + else: + feature_values_dict[feature_name] = ValueProto() + result_tuples.append((event_ts, feature_values_dict)) + return result_tuples + else: + error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={response.status_code}, error_message={response.text}" + logger.error(error_msg) + raise RuntimeError(error_msg) + def _construct_online_read_api_json_request( self, entity_keys: List[EntityKeyProto], @@ -197,6 +314,54 @@ def _construct_online_read_api_json_request( ) return req_body + def _construct_online_documents_api_json_request( + self, + table: FeatureView, + requested_features: Optional[List[str]] = None, + embedding: Optional[List[float]] = None, + top_k: Optional[int] = None, + distance_metric: Optional[str] = "L2", + ) -> str: + api_requested_features = [] + if requested_features is not None: + for requested_feature in requested_features: + api_requested_features.append(f"{table.name}:{requested_feature}") + + req_body = json.dumps( + { + "features": api_requested_features, + "embedding": embedding, + "top_k": top_k, + "distance_metric": distance_metric, + } + ) + return req_body + + def _construct_online_documents_v2_api_json_request( + self, + table: FeatureView, + embedding: Optional[List[float]], + top_k: int, + requested_features: Optional[List[str]] = None, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> str: + api_requested_features = [] + if requested_features is not None: + for requested_feature in requested_features: + api_requested_features.append(f"{table.name}:{requested_feature}") + + req_body = json.dumps( + { + "features": api_requested_features, + "embedding": embedding, + "top_k": top_k, + "distance_metric": distance_metric, + "query_string": query_string, + } + ) + return req_body + def _get_event_ts(self, response_json) -> datetime: event_ts = "" if len(response_json["results"]) > 1: @@ -239,6 +404,38 @@ def get_remote_online_features( ) +@rest_error_handling_decorator +def get_remote_online_documents( + session: requests.Session, config: RepoConfig, req_body: str +) -> requests.Response: + if config.online_store.cert: + return session.post( + f"{config.online_store.path}/retrieve-online-documents", + data=req_body, + verify=config.online_store.cert, + ) + else: + return session.post( + f"{config.online_store.path}/retrieve-online-documents", data=req_body + ) + + +@rest_error_handling_decorator +def get_remote_online_documents_v2( + session: requests.Session, config: RepoConfig, req_body: str +) -> requests.Response: + if config.online_store.cert: + return session.post( + f"{config.online_store.path}/retrieve-online-documents", + data=req_body, + verify=config.online_store.cert, + ) + else: + return session.post( + f"{config.online_store.path}/retrieve-online-documents", data=req_body + ) + + @rest_error_handling_decorator def post_remote_online_write( session: requests.Session, config: RepoConfig, req_body: dict From 7839fc529a38d21379e425a98bb62ff655466085 Mon Sep 17 00:00:00 2001 From: jyejare Date: Tue, 17 Jun 2025 19:02:47 +0530 Subject: [PATCH 2/6] Feature server rectified for documents retrival Signed-off-by: jyejare --- sdk/python/feast/feature_server.py | 8 +++++--- sdk/python/feast/infra/online_stores/remote.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index febf0fb6e2e..31132b6729b 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -117,7 +117,10 @@ class SaveDocumentRequest(BaseModel): data: dict -def _get_features(request: GetOnlineFeaturesRequest|GetOnlineDocumentsRequest, store: "feast.FeatureStore"): +def _get_features( + request: GetOnlineFeaturesRequest | GetOnlineDocumentsRequest, + store: "feast.FeatureStore", +): if request.feature_service: feature_service = store.get_feature_service( request.feature_service, allow_cache=True @@ -264,12 +267,11 @@ async def retrieve_online_documents( read_params = dict( features=features, query=request.query_embedding, - query_string=request.query_string, top_k=request.top_k, ) response = await run_in_threadpool( - lambda: store.retrieve_online_documents_v2(**read_params) # type: ignore + lambda: store.retrieve_online_documents(**read_params) # type: ignore ) # Convert the Protobuf object to JSON and return it diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index 0cbbb7632ba..c25ac3f4d72 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -174,11 +174,11 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - embedding: List[float], + requested_features: Optional[List[str]], + embedding: Optional[List[float]], top_k: int, - requested_features: Optional[List[str]] = None, distance_metric: Optional[str] = "L2", - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + ) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]]]: assert isinstance(config.online_store, RemoteOnlineStoreConfig) config.online_store.__class__ = RemoteOnlineStoreConfig @@ -330,7 +330,7 @@ def _construct_online_documents_api_json_request( req_body = json.dumps( { "features": api_requested_features, - "embedding": embedding, + "query_embedding": embedding, "top_k": top_k, "distance_metric": distance_metric, } From 18bfcf476e17ac6557fecc60bdcca278e8d22ae9 Mon Sep 17 00:00:00 2001 From: jyejare Date: Thu, 26 Jun 2025 19:08:50 +0530 Subject: [PATCH 3/6] Updates to remote retrival for response extraction Signed-off-by: jyejare --- .../feast/infra/online_stores/remote.py | 167 ++++++++++++++---- 1 file changed, 131 insertions(+), 36 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index c25ac3f4d72..68b45c619f3 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -31,6 +31,7 @@ feast_value_type_to_python_type, python_values_to_proto_values, ) +from feast.utils import _get_feature_view_vector_field_metadata from feast.value_type import ValueType logger = logging.getLogger(__name__) @@ -178,7 +179,7 @@ def retrieve_online_documents( embedding: Optional[List[float]], top_k: int, distance_metric: Optional[str] = "L2", - ) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]]]: + ) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: assert isinstance(config.online_store, RemoteOnlineStoreConfig) config.online_store.__class__ = RemoteOnlineStoreConfig @@ -190,38 +191,38 @@ def retrieve_online_documents( logger.debug("Able to retrieve the online documents from feature server.") response_json = json.loads(response.text) event_ts = self._get_event_ts(response_json) - result_tuples: List[ - Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] - ] = [] - for feature_value_index in range( - len(response_json["results"][0]["values"]) - ): - feature_values_dict: Dict[str, ValueProto] = dict() - for index, feature_name in enumerate( - response_json["metadata"]["feature_names"] - ): - if ( - requested_features is not None - and feature_name in requested_features - ): - if ( - response_json["results"][index]["statuses"][ - feature_value_index - ] - == "PRESENT" - ): - message = python_values_to_proto_values( - [ - response_json["results"][index]["values"][ - feature_value_index - ] - ], - ValueType.UNKNOWN, - ) - feature_values_dict[feature_name] = message[0] - else: - feature_values_dict[feature_name] = ValueProto() - result_tuples.append((event_ts, feature_values_dict)) + + # Create feature name to index mapping for efficient lookup + feature_name_to_index = { + name: idx for idx, name in enumerate(response_json["metadata"]["feature_names"]) + } + + vector_field_metadata = _get_feature_view_vector_field_metadata(table) + + # Extract feature names once + feature_names = response_json["metadata"]["feature_names"] + + # Process each result row + num_results = len(response_json["results"][0]["values"]) + result_tuples = [] + + for row_idx in range(num_results): + # Extract values using helper methods + feature_val = self._extract_requested_feature_value( + response_json, feature_name_to_index, requested_features, row_idx + ) + vector_value = self._extract_vector_field_value( + response_json, feature_name_to_index, vector_field_metadata, row_idx + ) + distance_val = self._extract_distance_value( + response_json, feature_name_to_index, 'distance', row_idx + ) + entity_key_proto = self._construct_entity_key_from_response( + response_json, row_idx, feature_name_to_index + ) + + result_tuples.append((event_ts, entity_key_proto, feature_val, vector_value, distance_val)) + return result_tuples else: error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={response.status_code}, error_message={response.text}" @@ -237,7 +238,7 @@ def retrieve_online_documents_v2( requested_features: Optional[List[str]] = None, distance_metric: Optional[str] = None, query_string: Optional[str] = None, - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + ) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]]]: assert isinstance(config.online_store, RemoteOnlineStoreConfig) config.online_store.__class__ = RemoteOnlineStoreConfig @@ -250,7 +251,7 @@ def retrieve_online_documents_v2( response_json = json.loads(response.text) event_ts = self._get_event_ts(response_json) result_tuples: List[ - Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] + Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]] ] = [] for feature_value_index in range( len(response_json["results"][0]["values"]) @@ -280,13 +281,78 @@ def retrieve_online_documents_v2( feature_values_dict[feature_name] = message[0] else: feature_values_dict[feature_name] = ValueProto() - result_tuples.append((event_ts, feature_values_dict)) + + # Create a dummy EntityKeyProto since remote store doesn't provide entity information + # This matches the behavior of the current implementation + entity_key_proto = None + + result_tuples.append((event_ts, entity_key_proto, feature_values_dict)) return result_tuples else: error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={response.status_code}, error_message={response.text}" logger.error(error_msg) raise RuntimeError(error_msg) + def _extract_requested_feature_value( + self, + response_json: dict, + feature_name_to_index: dict, + requested_features: Optional[List[str]], + row_idx: int + ) -> ValueProto: + """Extract the first available requested feature value.""" + if not requested_features: + return ValueProto() + + for feature_name in requested_features: + if feature_name in feature_name_to_index: + feature_idx = feature_name_to_index[feature_name] + if self._is_feature_present(response_json, feature_idx, row_idx): + return self._extract_feature_value(response_json, feature_idx, row_idx) + + return ValueProto() + + def _extract_vector_field_value( + self, + response_json: dict, + feature_name_to_index: dict, + vector_field_metadata, + row_idx: int + ) -> ValueProto: + """Extract vector field value from response.""" + if not vector_field_metadata or vector_field_metadata.name not in feature_name_to_index: + return ValueProto() + + vector_feature_idx = feature_name_to_index[vector_field_metadata.name] + if self._is_feature_present(response_json, vector_feature_idx, row_idx): + return self._extract_feature_value(response_json, vector_feature_idx, row_idx) + + return ValueProto() + + def _extract_distance_value( + self, + response_json: dict, + feature_name_to_index: dict, + distance_feature_name: str, + row_idx: int + ) -> ValueProto: + """Extract distance/score value from response.""" + if not distance_feature_name: + return ValueProto() + + distance_feature_idx = feature_name_to_index[distance_feature_name] + if self._is_feature_present(response_json, distance_feature_idx, row_idx): + distance_value = response_json["results"][distance_feature_idx]["values"][row_idx] + distance_val = ValueProto() + distance_val.float_val = float(distance_value) + return distance_val + + return ValueProto() + + def _is_feature_present(self, response_json: dict, feature_idx: int, row_idx: int) -> bool: + """Check if a feature is present in the response.""" + return response_json["results"][feature_idx]["statuses"][row_idx] == "PRESENT" + def _construct_online_read_api_json_request( self, entity_keys: List[EntityKeyProto], @@ -368,6 +434,35 @@ def _get_event_ts(self, response_json) -> datetime: event_ts = response_json["results"][1]["event_timestamps"][0] return datetime.fromisoformat(event_ts.replace("Z", "+00:00")) + def _construct_entity_key_from_response( + self, response_json: dict, row_idx: int, feature_name_to_index: dict + ) -> Optional[EntityKeyProto]: + """Construct EntityKeyProto from response data.""" + # Look for entity key fields in the response + entity_fields = [name for name in feature_name_to_index.keys() + if name.endswith('_id') or name in ['id', 'key', 'entity_id']] + + if not entity_fields: + return None + + entity_key_proto = EntityKeyProto() + entity_key_proto.join_keys.extend(entity_fields) + + for entity_field in entity_fields: + if entity_field in feature_name_to_index: + feature_idx = feature_name_to_index[entity_field] + if self._is_feature_present(response_json, feature_idx, row_idx): + entity_value = self._extract_feature_value(response_json, feature_idx, row_idx) + entity_key_proto.entity_values.append(entity_value) + + return entity_key_proto if entity_key_proto.entity_values else None + + def _extract_feature_value(self, response_json: dict, feature_idx: int, row_idx: int) -> ValueProto: + """Extract and convert a feature value to ValueProto.""" + raw_value = response_json["results"][feature_idx]["values"][row_idx] + proto_values = python_values_to_proto_values([raw_value]) + return proto_values[0] + def update( self, config: RepoConfig, From 2043b09d4216499837d879ae31eb170974987d67 Mon Sep 17 00:00:00 2001 From: jyejare Date: Fri, 27 Jun 2025 00:05:55 +0530 Subject: [PATCH 4/6] Remote retrive online doc v2 Signed-off-by: jyejare --- sdk/python/feast/feature_server.py | 20 ++-- sdk/python/feast/feature_store.py | 8 +- .../feast/infra/online_stores/remote.py | 95 ++++++++----------- 3 files changed, 60 insertions(+), 63 deletions(-) diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 31132b6729b..9c1613a46c4 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -95,8 +95,9 @@ class GetOnlineDocumentsRequest(BaseModel): features: Optional[List[str]] = None full_feature_names: bool = False top_k: Optional[int] = None - query_embedding: Optional[List[float]] = None + query: Optional[List[float]] = None query_string: Optional[str] = None + api_version: Optional[int] = 1 class ChatMessage(BaseModel): @@ -266,13 +267,20 @@ async def retrieve_online_documents( read_params = dict( features=features, - query=request.query_embedding, - top_k=request.top_k, + query=request.query, + top_k=request.top_k ) + if request.api_version == 2 and request.query_string is not None: + read_params['query_string'] = request.query_string - response = await run_in_threadpool( - lambda: store.retrieve_online_documents(**read_params) # type: ignore - ) + if request.api_version == 2: + response = await run_in_threadpool( + lambda: store.retrieve_online_documents_v2(**read_params) # type: ignore + ) + else: + response = await run_in_threadpool( + lambda: store.retrieve_online_documents(**read_params) # type: ignore + ) # Convert the Protobuf object to JSON and return it response_dict = await run_in_threadpool( diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 326daa589e8..40acb855854 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -2275,7 +2275,7 @@ def retrieve_online_documents_v2( distance_metric, query_string, ) - + def _retrieve_from_online_store( self, provider: Provider, @@ -2413,6 +2413,12 @@ def _retrieve_from_online_store_v2( output_len=output_len, ) + utils._populate_result_rows_from_columnar( + online_features_response=online_features_response, + data=entity_key_dict, + ) + + return OnlineResponse(online_features_response) def serve( diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index 68b45c619f3..026bf86975e 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -233,9 +233,9 @@ def retrieve_online_documents_v2( self, config: RepoConfig, table: FeatureView, + requested_features: Optional[List[str]], embedding: Optional[List[float]], top_k: int, - requested_features: Optional[List[str]] = None, distance_metric: Optional[str] = None, query_string: Optional[str] = None, ) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]]]: @@ -243,50 +243,45 @@ def retrieve_online_documents_v2( config.online_store.__class__ = RemoteOnlineStoreConfig req_body = self._construct_online_documents_v2_api_json_request( - table, requested_features, embedding, top_k, distance_metric, query_string + table, requested_features, embedding, top_k, distance_metric, query_string, api_version=2 ) - response = get_remote_online_documents_v2(config=config, req_body=req_body) + response = get_remote_online_documents(config=config, req_body=req_body) if response.status_code == 200: logger.debug("Able to retrieve the online documents from feature server.") response_json = json.loads(response.text) event_ts = self._get_event_ts(response_json) - result_tuples: List[ - Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]] - ] = [] - for feature_value_index in range( - len(response_json["results"][0]["values"]) - ): - feature_values_dict: Dict[str, ValueProto] = dict() - for index, feature_name in enumerate( - response_json["metadata"]["feature_names"] - ): - if ( - requested_features is not None - and feature_name in requested_features - ): - if ( - response_json["results"][index]["statuses"][ - feature_value_index - ] - == "PRESENT" - ): - message = python_values_to_proto_values( - [ - response_json["results"][index]["values"][ - feature_value_index - ] - ], - ValueType.UNKNOWN, - ) - feature_values_dict[feature_name] = message[0] - else: - feature_values_dict[feature_name] = ValueProto() + + # Create feature name to index mapping for efficient lookup + feature_name_to_index = { + name: idx for idx, name in enumerate(response_json["metadata"]["feature_names"]) + } - # Create a dummy EntityKeyProto since remote store doesn't provide entity information - # This matches the behavior of the current implementation - entity_key_proto = None + # Process each result row + num_results = len(response_json["results"][0]["values"]) if response_json["results"] else 0 + result_tuples = [] + + for row_idx in range(num_results): + # Build feature values dictionary for requested features + feature_values_dict: Dict[str, ValueProto] = {} + + if requested_features: + for feature_name in requested_features: + if feature_name in feature_name_to_index: + feature_idx = feature_name_to_index[feature_name] + if self._is_feature_present(response_json, feature_idx, row_idx): + feature_values_dict[feature_name] = self._extract_feature_value( + response_json, feature_idx, row_idx + ) + else: + feature_values_dict[feature_name] = ValueProto() + + # Construct entity key proto using existing helper method + entity_key_proto = self._construct_entity_key_from_response( + response_json, row_idx, feature_name_to_index + ) result_tuples.append((event_ts, entity_key_proto, feature_values_dict)) + return result_tuples else: error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={response.status_code}, error_message={response.text}" @@ -396,7 +391,7 @@ def _construct_online_documents_api_json_request( req_body = json.dumps( { "features": api_requested_features, - "query_embedding": embedding, + "query": embedding, "top_k": top_k, "distance_metric": distance_metric, } @@ -406,11 +401,12 @@ def _construct_online_documents_api_json_request( def _construct_online_documents_v2_api_json_request( self, table: FeatureView, + requested_features: Optional[List[str]], embedding: Optional[List[float]], top_k: int, - requested_features: Optional[List[str]] = None, distance_metric: Optional[str] = None, query_string: Optional[str] = None, + api_version: Optional[int] = 1, ) -> str: api_requested_features = [] if requested_features is not None: @@ -420,10 +416,11 @@ def _construct_online_documents_v2_api_json_request( req_body = json.dumps( { "features": api_requested_features, - "embedding": embedding, + "query": embedding, "top_k": top_k, "distance_metric": distance_metric, "query_string": query_string, + "api_version": api_version, } ) return req_body @@ -460,6 +457,8 @@ def _construct_entity_key_from_response( def _extract_feature_value(self, response_json: dict, feature_idx: int, row_idx: int) -> ValueProto: """Extract and convert a feature value to ValueProto.""" raw_value = response_json["results"][feature_idx]["values"][row_idx] + if raw_value is None: + return ValueProto() proto_values = python_values_to_proto_values([raw_value]) return proto_values[0] @@ -515,22 +514,6 @@ def get_remote_online_documents( ) -@rest_error_handling_decorator -def get_remote_online_documents_v2( - session: requests.Session, config: RepoConfig, req_body: str -) -> requests.Response: - if config.online_store.cert: - return session.post( - f"{config.online_store.path}/retrieve-online-documents", - data=req_body, - verify=config.online_store.cert, - ) - else: - return session.post( - f"{config.online_store.path}/retrieve-online-documents", data=req_body - ) - - @rest_error_handling_decorator def post_remote_online_write( session: requests.Session, config: RepoConfig, req_body: dict From 8bcb2ae2d34fd19f12131d7c75fc7ff7dac4a052 Mon Sep 17 00:00:00 2001 From: jyejare Date: Tue, 1 Jul 2025 13:06:31 +0530 Subject: [PATCH 5/6] Unit tests for Remote docuemnts retrival Signed-off-by: jyejare --- .../online_store/test_remote_online_store.py | 366 ++++++++++++++++++ 1 file changed, 366 insertions(+) create mode 100644 sdk/python/tests/unit/infra/online_store/test_remote_online_store.py diff --git a/sdk/python/tests/unit/infra/online_store/test_remote_online_store.py b/sdk/python/tests/unit/infra/online_store/test_remote_online_store.py new file mode 100644 index 00000000000..f8b59db79bd --- /dev/null +++ b/sdk/python/tests/unit/infra/online_store/test_remote_online_store.py @@ -0,0 +1,366 @@ +import json +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock +from typing import List, Optional + +from feast import Entity, FeatureView, Field, FileSource, RepoConfig +from feast.infra.online_stores.remote import RemoteOnlineStore, RemoteOnlineStoreConfig +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.types import Float32, String, Int64 +from feast.value_type import ValueType + + +class TestRemoteOnlineStoreRetrieveDocuments: + """Test suite for retrieve_online_documents and retrieve_online_documents_v2 methods.""" + + @pytest.fixture + def remote_store(self): + """Create a RemoteOnlineStore instance for testing.""" + return RemoteOnlineStore() + + @pytest.fixture + def config(self): + """Create a RepoConfig with RemoteOnlineStoreConfig.""" + return RepoConfig( + project="test_project", + online_store=RemoteOnlineStoreConfig( + type="remote", + path="http://localhost:6566" + ), + registry="dummy_registry" + ) + + @pytest.fixture + def config_with_cert(self): + """Create a RepoConfig with RemoteOnlineStoreConfig including TLS cert.""" + return RepoConfig( + project="test_project", + online_store=RemoteOnlineStoreConfig( + type="remote", + path="http://localhost:6566", + cert="/path/to/cert.pem" + ), + registry="dummy_registry" + ) + + @pytest.fixture + def feature_view(self): + """Create a test FeatureView.""" + entity = Entity(name="user_id", description="User ID", value_type=ValueType.INT64) + source = FileSource( + path="test.parquet", + timestamp_field="event_timestamp" + ) + return FeatureView( + name="test_feature_view", + entities=[entity], + ttl=timedelta(days=1), + schema=[ + Field(name="user_id", dtype=Int64), # Entity field + Field(name="feature1", dtype=String), + Field(name="embedding", dtype=Float32), + ], + source=source, + ) + + @pytest.fixture + def mock_successful_response(self): + """Create a mock successful HTTP response for documents retrieval.""" + return { + "metadata": { + "feature_names": ["feature1", "embedding", "distance", "user_id"] + }, + "results": [ + { + "values": ["test_value_1", "test_value_2"], + "statuses": ["PRESENT", "PRESENT"] + }, # feature1 + { + "values": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + "statuses": ["PRESENT", "PRESENT"], + "event_timestamps": ["2023-01-01T00:00:00Z", "2023-01-01T01:00:00Z"] + }, # embedding + { + "values": [0.85, 0.92], + "statuses": ["PRESENT", "PRESENT"] + }, # distance + { + "values": [123, 456], + "statuses": ["PRESENT", "PRESENT"] + } # user_id + ] + } + + @pytest.fixture + def mock_successful_response_v2(self): + """Create a mock successful HTTP response for documents retrieval v2.""" + return { + "metadata": { + "feature_names": ["user_id", "feature1"] + }, + "results": [ + { + "values": [123, 456], + "statuses": ["PRESENT", "PRESENT"] + }, # user_id + { + "values": ["test_value_1", "test_value_2"], + "statuses": ["PRESENT", "PRESENT"], + "event_timestamps": ["2023-01-01T00:00:00Z", "2023-01-01T01:00:00Z"] + } # feature1 + ] + } + + @patch('feast.infra.online_stores.remote.get_remote_online_documents') + def test_retrieve_online_documents_success( + self, + mock_get_remote_online_documents, + remote_store, + config, + feature_view, + mock_successful_response + ): + """Test successful retrieve_online_documents call.""" + # Setup mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.text = json.dumps(mock_successful_response) + mock_get_remote_online_documents.return_value = mock_response + + # Call the method + result = remote_store.retrieve_online_documents( + config=config, + table=feature_view, + requested_features=["feature1"], + embedding=[0.1, 0.2, 0.3], + top_k=2, + distance_metric="L2" + ) + + # Verify the call was made correctly + mock_get_remote_online_documents.assert_called_once() + call_args = mock_get_remote_online_documents.call_args + assert call_args[1]['config'] == config + + # Parse the request body to verify it's correct + req_body = json.loads(call_args[1]['req_body']) + assert req_body['features'] == ['test_feature_view:feature1'] + assert req_body['query'] == [0.1, 0.2, 0.3] + assert req_body['top_k'] == 2 + assert req_body['distance_metric'] == "L2" + + # Verify the result + assert len(result) == 2 + event_ts, entity_key_proto, feature_val, vector_value, distance_val = result[0] + + # Check event timestamp + assert isinstance(event_ts, datetime) + + # Check that we got ValueProto objects + assert isinstance(feature_val, ValueProto) + assert isinstance(vector_value, ValueProto) + assert isinstance(distance_val, ValueProto) + + @patch('feast.infra.online_stores.remote.get_remote_online_documents') + def test_retrieve_online_documents_v2_success( + self, + mock_get_remote_online_documents, + remote_store, + config, + feature_view, + mock_successful_response_v2 + ): + """Test successful retrieve_online_documents_v2 call.""" + # Setup mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.text = json.dumps(mock_successful_response_v2) + mock_get_remote_online_documents.return_value = mock_response + + # Call the method + result = remote_store.retrieve_online_documents_v2( + config=config, + table=feature_view, + requested_features=["feature1"], + embedding=[0.1, 0.2, 0.3], + top_k=2, + distance_metric="cosine", + query_string="test query" + ) + + # Verify the call was made correctly + mock_get_remote_online_documents.assert_called_once() + call_args = mock_get_remote_online_documents.call_args + assert call_args[1]['config'] == config + + # Parse the request body to verify it's correct + req_body = json.loads(call_args[1]['req_body']) + assert req_body['features'] == ['test_feature_view:feature1'] + assert req_body['query'] == [0.1, 0.2, 0.3] + assert req_body['top_k'] == 2 + assert req_body['distance_metric'] == "cosine" + assert req_body['query_string'] == "test query" + assert req_body['api_version'] == 2 + + # Verify the result + assert len(result) == 2 + event_ts, entity_key_proto, feature_values_dict = result[0] + + # Check event timestamp + assert isinstance(event_ts, datetime) + + # Check entity key proto + assert isinstance(entity_key_proto, EntityKeyProto) + + # Check feature values dictionary + assert isinstance(feature_values_dict, dict) + assert "feature1" in feature_values_dict + assert isinstance(feature_values_dict["feature1"], ValueProto) + + @patch('feast.infra.online_stores.remote.get_remote_online_documents') + def test_retrieve_online_documents_with_cert( + self, + mock_get_remote_online_documents, + remote_store, + config_with_cert, + feature_view, + mock_successful_response + ): + """Test retrieve_online_documents with TLS certificate.""" + # Setup mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.text = json.dumps(mock_successful_response) + mock_get_remote_online_documents.return_value = mock_response + + # Call the method + result = remote_store.retrieve_online_documents( + config=config_with_cert, + table=feature_view, + requested_features=["feature1"], + embedding=[0.1, 0.2, 0.3], + top_k=1 + ) + + # Verify the call was made + mock_get_remote_online_documents.assert_called_once() + assert len(result) == 2 + + @patch('feast.infra.online_stores.remote.get_remote_online_documents') + def test_retrieve_online_documents_error_response( + self, + mock_get_remote_online_documents, + remote_store, + config, + feature_view + ): + """Test retrieve_online_documents with error response.""" + # Setup mock error response + mock_response = Mock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_get_remote_online_documents.return_value = mock_response + + # Call the method and expect RuntimeError + with pytest.raises(RuntimeError, match="Unable to retrieve the online documents using feature server API"): + remote_store.retrieve_online_documents( + config=config, + table=feature_view, + requested_features=["feature1"], + embedding=[0.1, 0.2, 0.3], + top_k=1 + ) + + @patch('feast.infra.online_stores.remote.get_remote_online_documents') + def test_retrieve_online_documents_v2_error_response( + self, + mock_get_remote_online_documents, + remote_store, + config, + feature_view + ): + """Test retrieve_online_documents_v2 with error response.""" + # Setup mock error response + mock_response = Mock() + mock_response.status_code = 404 + mock_response.text = "Not Found" + mock_get_remote_online_documents.return_value = mock_response + + # Call the method and expect RuntimeError + with pytest.raises(RuntimeError, match="Unable to retrieve the online documents using feature server API"): + remote_store.retrieve_online_documents_v2( + config=config, + table=feature_view, + requested_features=["feature1"], + embedding=[0.1, 0.2, 0.3], + top_k=1 + ) + + def test_construct_online_documents_api_json_request(self, remote_store, feature_view): + """Test _construct_online_documents_api_json_request method.""" + result = remote_store._construct_online_documents_api_json_request( + table=feature_view, + requested_features=["feature1", "feature2"], + embedding=[0.1, 0.2, 0.3], + top_k=5, + distance_metric="cosine" + ) + + parsed_result = json.loads(result) + assert parsed_result["features"] == ["test_feature_view:feature1", "test_feature_view:feature2"] + assert parsed_result["query"] == [0.1, 0.2, 0.3] + assert parsed_result["top_k"] == 5 + assert parsed_result["distance_metric"] == "cosine" + + def test_construct_online_documents_v2_api_json_request(self, remote_store, feature_view): + """Test _construct_online_documents_v2_api_json_request method.""" + result = remote_store._construct_online_documents_v2_api_json_request( + table=feature_view, + requested_features=["feature1"], + embedding=[0.1, 0.2], + top_k=3, + distance_metric="L2", + query_string="test query", + api_version=2 + ) + + parsed_result = json.loads(result) + assert parsed_result["features"] == ["test_feature_view:feature1"] + assert parsed_result["query"] == [0.1, 0.2] + assert parsed_result["top_k"] == 3 + assert parsed_result["distance_metric"] == "L2" + assert parsed_result["query_string"] == "test query" + assert parsed_result["api_version"] == 2 + + + def test_extract_requested_feature_value(self, remote_store): + """Test _extract_requested_feature_value helper method.""" + response_json = { + "results": [ + { + "values": ["test_value"], + "statuses": ["PRESENT"] + } + ] + } + feature_name_to_index = {"feature1": 0} + + result = remote_store._extract_requested_feature_value( + response_json, feature_name_to_index, ["feature1"], 0 + ) + assert isinstance(result, ValueProto) + + def test_is_feature_present(self, remote_store): + """Test _is_feature_present helper method.""" + response_json = { + "results": [ + { + "statuses": ["PRESENT", "NOT_FOUND"] + } + ] + } + + assert remote_store._is_feature_present(response_json, 0, 0) == True + assert remote_store._is_feature_present(response_json, 0, 1) == False \ No newline at end of file From dfafa56b8e2514563a8b2119a37060c336253794 Mon Sep 17 00:00:00 2001 From: jyejare Date: Tue, 1 Jul 2025 13:16:17 +0530 Subject: [PATCH 6/6] Fixed checks and comments Signed-off-by: jyejare --- sdk/python/feast/feature_server.py | 16 +- sdk/python/feast/feature_store.py | 3 +- .../feast/infra/online_stores/remote.py | 145 +++++++---- sdk/python/feast/utils.py | 2 +- .../online_store/test_remote_online_store.py | 225 +++++++++--------- 5 files changed, 218 insertions(+), 173 deletions(-) diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 9c1613a46c4..8593512d5c6 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -6,7 +6,7 @@ import traceback from contextlib import asynccontextmanager from importlib import resources as importlib_resources -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import pandas as pd import psutil @@ -86,13 +86,13 @@ class MaterializeIncrementalRequest(BaseModel): class GetOnlineFeaturesRequest(BaseModel): entities: Dict[str, List[Any]] feature_service: Optional[str] = None - features: Optional[List[str]] = None + features: List[str] = [] full_feature_names: bool = False class GetOnlineDocumentsRequest(BaseModel): feature_service: Optional[str] = None - features: Optional[List[str]] = None + features: List[str] = [] full_feature_names: bool = False top_k: Optional[int] = None query: Optional[List[float]] = None @@ -119,7 +119,7 @@ class SaveDocumentRequest(BaseModel): def _get_features( - request: GetOnlineFeaturesRequest | GetOnlineDocumentsRequest, + request: Union[GetOnlineFeaturesRequest, GetOnlineDocumentsRequest], store: "feast.FeatureStore", ): if request.feature_service: @@ -265,13 +265,9 @@ async def retrieve_online_documents( # Initialize parameters for FeatureStore.retrieve_online_documents_v2(...) call features = await run_in_threadpool(_get_features, request, store) - read_params = dict( - features=features, - query=request.query, - top_k=request.top_k - ) + read_params = dict(features=features, query=request.query, top_k=request.top_k) if request.api_version == 2 and request.query_string is not None: - read_params['query_string'] = request.query_string + read_params["query_string"] = request.query_string if request.api_version == 2: response = await run_in_threadpool( diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 40acb855854..cfad8178c05 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -2275,7 +2275,7 @@ def retrieve_online_documents_v2( distance_metric, query_string, ) - + def _retrieve_from_online_store( self, provider: Provider, @@ -2418,7 +2418,6 @@ def _retrieve_from_online_store_v2( data=entity_key_dict, ) - return OnlineResponse(online_features_response) def serve( diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index 026bf86975e..ec2b05759ba 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -179,7 +179,15 @@ def retrieve_online_documents( embedding: Optional[List[float]], top_k: int, distance_metric: Optional[str] = "L2", - ) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: assert isinstance(config.online_store, RemoteOnlineStoreConfig) config.online_store.__class__ = RemoteOnlineStoreConfig @@ -190,18 +198,16 @@ def retrieve_online_documents( if response.status_code == 200: logger.debug("Able to retrieve the online documents from feature server.") response_json = json.loads(response.text) - event_ts = self._get_event_ts(response_json) + event_ts: Optional[datetime] = self._get_event_ts(response_json) # Create feature name to index mapping for efficient lookup feature_name_to_index = { - name: idx for idx, name in enumerate(response_json["metadata"]["feature_names"]) + name: idx + for idx, name in enumerate(response_json["metadata"]["feature_names"]) } vector_field_metadata = _get_feature_view_vector_field_metadata(table) - # Extract feature names once - feature_names = response_json["metadata"]["feature_names"] - # Process each result row num_results = len(response_json["results"][0]["values"]) result_tuples = [] @@ -215,13 +221,21 @@ def retrieve_online_documents( response_json, feature_name_to_index, vector_field_metadata, row_idx ) distance_val = self._extract_distance_value( - response_json, feature_name_to_index, 'distance', row_idx + response_json, feature_name_to_index, "distance", row_idx ) entity_key_proto = self._construct_entity_key_from_response( - response_json, row_idx, feature_name_to_index + response_json, row_idx, feature_name_to_index, table ) - result_tuples.append((event_ts, entity_key_proto, feature_val, vector_value, distance_val)) + result_tuples.append( + ( + event_ts, + entity_key_proto, + feature_val, + vector_value, + distance_val, + ) + ) return result_tuples else: @@ -238,50 +252,77 @@ def retrieve_online_documents_v2( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, - ) -> List[Tuple[Optional[datetime], Optional[EntityKeyProto], Optional[Dict[str, ValueProto]]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: assert isinstance(config.online_store, RemoteOnlineStoreConfig) config.online_store.__class__ = RemoteOnlineStoreConfig req_body = self._construct_online_documents_v2_api_json_request( - table, requested_features, embedding, top_k, distance_metric, query_string, api_version=2 + table, + requested_features, + embedding, + top_k, + distance_metric, + query_string, + api_version=2, ) response = get_remote_online_documents(config=config, req_body=req_body) if response.status_code == 200: logger.debug("Able to retrieve the online documents from feature server.") response_json = json.loads(response.text) - event_ts = self._get_event_ts(response_json) - + event_ts: Optional[datetime] = self._get_event_ts(response_json) + # Create feature name to index mapping for efficient lookup feature_name_to_index = { - name: idx for idx, name in enumerate(response_json["metadata"]["feature_names"]) + name: idx + for idx, name in enumerate(response_json["metadata"]["feature_names"]) } # Process each result row - num_results = len(response_json["results"][0]["values"]) if response_json["results"] else 0 + num_results = ( + len(response_json["results"][0]["values"]) + if response_json["results"] + else 0 + ) result_tuples = [] for row_idx in range(num_results): # Build feature values dictionary for requested features - feature_values_dict: Dict[str, ValueProto] = {} - + feature_values_dict = {} + if requested_features: for feature_name in requested_features: if feature_name in feature_name_to_index: feature_idx = feature_name_to_index[feature_name] - if self._is_feature_present(response_json, feature_idx, row_idx): - feature_values_dict[feature_name] = self._extract_feature_value( - response_json, feature_idx, row_idx + if self._is_feature_present( + response_json, feature_idx, row_idx + ): + feature_values_dict[feature_name] = ( + self._extract_feature_value( + response_json, feature_idx, row_idx + ) ) else: feature_values_dict[feature_name] = ValueProto() # Construct entity key proto using existing helper method entity_key_proto = self._construct_entity_key_from_response( - response_json, row_idx, feature_name_to_index + response_json, row_idx, feature_name_to_index, table + ) + + result_tuples.append( + ( + event_ts, + entity_key_proto, + feature_values_dict if feature_values_dict else None, + ) ) - result_tuples.append((event_ts, entity_key_proto, feature_values_dict)) - return result_tuples else: error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={response.status_code}, error_message={response.text}" @@ -293,8 +334,8 @@ def _extract_requested_feature_value( response_json: dict, feature_name_to_index: dict, requested_features: Optional[List[str]], - row_idx: int - ) -> ValueProto: + row_idx: int, + ) -> Optional[ValueProto]: """Extract the first available requested feature value.""" if not requested_features: return ValueProto() @@ -303,7 +344,9 @@ def _extract_requested_feature_value( if feature_name in feature_name_to_index: feature_idx = feature_name_to_index[feature_name] if self._is_feature_present(response_json, feature_idx, row_idx): - return self._extract_feature_value(response_json, feature_idx, row_idx) + return self._extract_feature_value( + response_json, feature_idx, row_idx + ) return ValueProto() @@ -312,15 +355,20 @@ def _extract_vector_field_value( response_json: dict, feature_name_to_index: dict, vector_field_metadata, - row_idx: int - ) -> ValueProto: + row_idx: int, + ) -> Optional[ValueProto]: """Extract vector field value from response.""" - if not vector_field_metadata or vector_field_metadata.name not in feature_name_to_index: + if ( + not vector_field_metadata + or vector_field_metadata.name not in feature_name_to_index + ): return ValueProto() vector_feature_idx = feature_name_to_index[vector_field_metadata.name] if self._is_feature_present(response_json, vector_feature_idx, row_idx): - return self._extract_feature_value(response_json, vector_feature_idx, row_idx) + return self._extract_feature_value( + response_json, vector_feature_idx, row_idx + ) return ValueProto() @@ -329,22 +377,26 @@ def _extract_distance_value( response_json: dict, feature_name_to_index: dict, distance_feature_name: str, - row_idx: int - ) -> ValueProto: + row_idx: int, + ) -> Optional[ValueProto]: """Extract distance/score value from response.""" if not distance_feature_name: return ValueProto() distance_feature_idx = feature_name_to_index[distance_feature_name] if self._is_feature_present(response_json, distance_feature_idx, row_idx): - distance_value = response_json["results"][distance_feature_idx]["values"][row_idx] + distance_value = response_json["results"][distance_feature_idx]["values"][ + row_idx + ] distance_val = ValueProto() distance_val.float_val = float(distance_value) return distance_val return ValueProto() - def _is_feature_present(self, response_json: dict, feature_idx: int, row_idx: int) -> bool: + def _is_feature_present( + self, response_json: dict, feature_idx: int, row_idx: int + ) -> bool: """Check if a feature is present in the response.""" return response_json["results"][feature_idx]["statuses"][row_idx] == "PRESENT" @@ -406,7 +458,7 @@ def _construct_online_documents_v2_api_json_request( top_k: int, distance_metric: Optional[str] = None, query_string: Optional[str] = None, - api_version: Optional[int] = 1, + api_version: Optional[int] = 2, ) -> str: api_requested_features = [] if requested_features is not None: @@ -432,12 +484,19 @@ def _get_event_ts(self, response_json) -> datetime: return datetime.fromisoformat(event_ts.replace("Z", "+00:00")) def _construct_entity_key_from_response( - self, response_json: dict, row_idx: int, feature_name_to_index: dict + self, + response_json: dict, + row_idx: int, + feature_name_to_index: dict, + table: FeatureView, ) -> Optional[EntityKeyProto]: """Construct EntityKeyProto from response data.""" - # Look for entity key fields in the response - entity_fields = [name for name in feature_name_to_index.keys() - if name.endswith('_id') or name in ['id', 'key', 'entity_id']] + # Use the feature view's join_keys to identify entity fields + entity_fields = [ + join_key + for join_key in table.join_keys + if join_key in feature_name_to_index + ] if not entity_fields: return None @@ -449,12 +508,16 @@ def _construct_entity_key_from_response( if entity_field in feature_name_to_index: feature_idx = feature_name_to_index[entity_field] if self._is_feature_present(response_json, feature_idx, row_idx): - entity_value = self._extract_feature_value(response_json, feature_idx, row_idx) + entity_value = self._extract_feature_value( + response_json, feature_idx, row_idx + ) entity_key_proto.entity_values.append(entity_value) return entity_key_proto if entity_key_proto.entity_values else None - def _extract_feature_value(self, response_json: dict, feature_idx: int, row_idx: int) -> ValueProto: + def _extract_feature_value( + self, response_json: dict, feature_idx: int, row_idx: int + ) -> ValueProto: """Extract and convert a feature value to ValueProto.""" raw_value = response_json["results"][feature_idx]["values"][row_idx] if raw_value is None: diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index c63dad6a6ab..1f629f61f21 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1048,7 +1048,7 @@ def _list_feature_views( def _get_feature_views_to_use( registry: "BaseRegistry", project, - features: Optional[Union[List[str], "FeatureService"]], + features: Union[List[str], "FeatureService"], allow_cache=False, hide_dummy_entity: bool = True, ) -> Tuple[List["FeatureView"], List["OnDemandFeatureView"]]: diff --git a/sdk/python/tests/unit/infra/online_store/test_remote_online_store.py b/sdk/python/tests/unit/infra/online_store/test_remote_online_store.py index f8b59db79bd..1c074a40d40 100644 --- a/sdk/python/tests/unit/infra/online_store/test_remote_online_store.py +++ b/sdk/python/tests/unit/infra/online_store/test_remote_online_store.py @@ -1,14 +1,14 @@ import json -import pytest from datetime import datetime, timedelta -from unittest.mock import Mock, patch, MagicMock -from typing import List, Optional +from unittest.mock import Mock, patch + +import pytest from feast import Entity, FeatureView, Field, FileSource, RepoConfig from feast.infra.online_stores.remote import RemoteOnlineStore, RemoteOnlineStoreConfig from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.types import Float32, String, Int64 +from feast.types import Float32, Int64, String from feast.value_type import ValueType @@ -26,10 +26,9 @@ def config(self): return RepoConfig( project="test_project", online_store=RemoteOnlineStoreConfig( - type="remote", - path="http://localhost:6566" + type="remote", path="http://localhost:6566" ), - registry="dummy_registry" + registry="dummy_registry", ) @pytest.fixture @@ -38,21 +37,18 @@ def config_with_cert(self): return RepoConfig( project="test_project", online_store=RemoteOnlineStoreConfig( - type="remote", - path="http://localhost:6566", - cert="/path/to/cert.pem" + type="remote", path="http://localhost:6566", cert="/path/to/cert.pem" ), - registry="dummy_registry" + registry="dummy_registry", ) @pytest.fixture def feature_view(self): """Create a test FeatureView.""" - entity = Entity(name="user_id", description="User ID", value_type=ValueType.INT64) - source = FileSource( - path="test.parquet", - timestamp_field="event_timestamp" + entity = Entity( + name="user_id", description="User ID", value_type=ValueType.INT64 ) + source = FileSource(path="test.parquet", timestamp_field="event_timestamp") return FeatureView( name="test_feature_view", entities=[entity], @@ -75,52 +71,50 @@ def mock_successful_response(self): "results": [ { "values": ["test_value_1", "test_value_2"], - "statuses": ["PRESENT", "PRESENT"] + "statuses": ["PRESENT", "PRESENT"], }, # feature1 { "values": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], "statuses": ["PRESENT", "PRESENT"], - "event_timestamps": ["2023-01-01T00:00:00Z", "2023-01-01T01:00:00Z"] + "event_timestamps": [ + "2023-01-01T00:00:00Z", + "2023-01-01T01:00:00Z", + ], }, # embedding { "values": [0.85, 0.92], - "statuses": ["PRESENT", "PRESENT"] + "statuses": ["PRESENT", "PRESENT"], }, # distance - { - "values": [123, 456], - "statuses": ["PRESENT", "PRESENT"] - } # user_id - ] + {"values": [123, 456], "statuses": ["PRESENT", "PRESENT"]}, # user_id + ], } @pytest.fixture def mock_successful_response_v2(self): """Create a mock successful HTTP response for documents retrieval v2.""" return { - "metadata": { - "feature_names": ["user_id", "feature1"] - }, + "metadata": {"feature_names": ["user_id", "feature1"]}, "results": [ - { - "values": [123, 456], - "statuses": ["PRESENT", "PRESENT"] - }, # user_id + {"values": [123, 456], "statuses": ["PRESENT", "PRESENT"]}, # user_id { "values": ["test_value_1", "test_value_2"], "statuses": ["PRESENT", "PRESENT"], - "event_timestamps": ["2023-01-01T00:00:00Z", "2023-01-01T01:00:00Z"] - } # feature1 - ] + "event_timestamps": [ + "2023-01-01T00:00:00Z", + "2023-01-01T01:00:00Z", + ], + }, # feature1 + ], } - @patch('feast.infra.online_stores.remote.get_remote_online_documents') + @patch("feast.infra.online_stores.remote.get_remote_online_documents") def test_retrieve_online_documents_success( - self, - mock_get_remote_online_documents, - remote_store, - config, - feature_view, - mock_successful_response + self, + mock_get_remote_online_documents, + remote_store, + config, + feature_view, + mock_successful_response, ): """Test successful retrieve_online_documents call.""" # Setup mock response @@ -136,41 +130,41 @@ def test_retrieve_online_documents_success( requested_features=["feature1"], embedding=[0.1, 0.2, 0.3], top_k=2, - distance_metric="L2" + distance_metric="L2", ) # Verify the call was made correctly mock_get_remote_online_documents.assert_called_once() call_args = mock_get_remote_online_documents.call_args - assert call_args[1]['config'] == config - + assert call_args[1]["config"] == config + # Parse the request body to verify it's correct - req_body = json.loads(call_args[1]['req_body']) - assert req_body['features'] == ['test_feature_view:feature1'] - assert req_body['query'] == [0.1, 0.2, 0.3] - assert req_body['top_k'] == 2 - assert req_body['distance_metric'] == "L2" + req_body = json.loads(call_args[1]["req_body"]) + assert req_body["features"] == ["test_feature_view:feature1"] + assert req_body["query"] == [0.1, 0.2, 0.3] + assert req_body["top_k"] == 2 + assert req_body["distance_metric"] == "L2" # Verify the result assert len(result) == 2 event_ts, entity_key_proto, feature_val, vector_value, distance_val = result[0] - + # Check event timestamp assert isinstance(event_ts, datetime) - + # Check that we got ValueProto objects assert isinstance(feature_val, ValueProto) assert isinstance(vector_value, ValueProto) assert isinstance(distance_val, ValueProto) - @patch('feast.infra.online_stores.remote.get_remote_online_documents') + @patch("feast.infra.online_stores.remote.get_remote_online_documents") def test_retrieve_online_documents_v2_success( - self, - mock_get_remote_online_documents, - remote_store, - config, - feature_view, - mock_successful_response_v2 + self, + mock_get_remote_online_documents, + remote_store, + config, + feature_view, + mock_successful_response_v2, ): """Test successful retrieve_online_documents_v2 call.""" # Setup mock response @@ -187,46 +181,46 @@ def test_retrieve_online_documents_v2_success( embedding=[0.1, 0.2, 0.3], top_k=2, distance_metric="cosine", - query_string="test query" + query_string="test query", ) # Verify the call was made correctly mock_get_remote_online_documents.assert_called_once() call_args = mock_get_remote_online_documents.call_args - assert call_args[1]['config'] == config - + assert call_args[1]["config"] == config + # Parse the request body to verify it's correct - req_body = json.loads(call_args[1]['req_body']) - assert req_body['features'] == ['test_feature_view:feature1'] - assert req_body['query'] == [0.1, 0.2, 0.3] - assert req_body['top_k'] == 2 - assert req_body['distance_metric'] == "cosine" - assert req_body['query_string'] == "test query" - assert req_body['api_version'] == 2 + req_body = json.loads(call_args[1]["req_body"]) + assert req_body["features"] == ["test_feature_view:feature1"] + assert req_body["query"] == [0.1, 0.2, 0.3] + assert req_body["top_k"] == 2 + assert req_body["distance_metric"] == "cosine" + assert req_body["query_string"] == "test query" + assert req_body["api_version"] == 2 # Verify the result assert len(result) == 2 event_ts, entity_key_proto, feature_values_dict = result[0] - + # Check event timestamp assert isinstance(event_ts, datetime) - + # Check entity key proto assert isinstance(entity_key_proto, EntityKeyProto) - + # Check feature values dictionary assert isinstance(feature_values_dict, dict) assert "feature1" in feature_values_dict assert isinstance(feature_values_dict["feature1"], ValueProto) - @patch('feast.infra.online_stores.remote.get_remote_online_documents') + @patch("feast.infra.online_stores.remote.get_remote_online_documents") def test_retrieve_online_documents_with_cert( - self, - mock_get_remote_online_documents, - remote_store, - config_with_cert, - feature_view, - mock_successful_response + self, + mock_get_remote_online_documents, + remote_store, + config_with_cert, + feature_view, + mock_successful_response, ): """Test retrieve_online_documents with TLS certificate.""" # Setup mock response @@ -241,20 +235,16 @@ def test_retrieve_online_documents_with_cert( table=feature_view, requested_features=["feature1"], embedding=[0.1, 0.2, 0.3], - top_k=1 + top_k=1, ) # Verify the call was made mock_get_remote_online_documents.assert_called_once() assert len(result) == 2 - @patch('feast.infra.online_stores.remote.get_remote_online_documents') + @patch("feast.infra.online_stores.remote.get_remote_online_documents") def test_retrieve_online_documents_error_response( - self, - mock_get_remote_online_documents, - remote_store, - config, - feature_view + self, mock_get_remote_online_documents, remote_store, config, feature_view ): """Test retrieve_online_documents with error response.""" # Setup mock error response @@ -264,22 +254,21 @@ def test_retrieve_online_documents_error_response( mock_get_remote_online_documents.return_value = mock_response # Call the method and expect RuntimeError - with pytest.raises(RuntimeError, match="Unable to retrieve the online documents using feature server API"): + with pytest.raises( + RuntimeError, + match="Unable to retrieve the online documents using feature server API", + ): remote_store.retrieve_online_documents( config=config, table=feature_view, requested_features=["feature1"], embedding=[0.1, 0.2, 0.3], - top_k=1 + top_k=1, ) - @patch('feast.infra.online_stores.remote.get_remote_online_documents') + @patch("feast.infra.online_stores.remote.get_remote_online_documents") def test_retrieve_online_documents_v2_error_response( - self, - mock_get_remote_online_documents, - remote_store, - config, - feature_view + self, mock_get_remote_online_documents, remote_store, config, feature_view ): """Test retrieve_online_documents_v2 with error response.""" # Setup mock error response @@ -289,32 +278,42 @@ def test_retrieve_online_documents_v2_error_response( mock_get_remote_online_documents.return_value = mock_response # Call the method and expect RuntimeError - with pytest.raises(RuntimeError, match="Unable to retrieve the online documents using feature server API"): + with pytest.raises( + RuntimeError, + match="Unable to retrieve the online documents using feature server API", + ): remote_store.retrieve_online_documents_v2( config=config, table=feature_view, requested_features=["feature1"], embedding=[0.1, 0.2, 0.3], - top_k=1 + top_k=1, ) - def test_construct_online_documents_api_json_request(self, remote_store, feature_view): + def test_construct_online_documents_api_json_request( + self, remote_store, feature_view + ): """Test _construct_online_documents_api_json_request method.""" result = remote_store._construct_online_documents_api_json_request( table=feature_view, requested_features=["feature1", "feature2"], embedding=[0.1, 0.2, 0.3], top_k=5, - distance_metric="cosine" + distance_metric="cosine", ) - + parsed_result = json.loads(result) - assert parsed_result["features"] == ["test_feature_view:feature1", "test_feature_view:feature2"] + assert parsed_result["features"] == [ + "test_feature_view:feature1", + "test_feature_view:feature2", + ] assert parsed_result["query"] == [0.1, 0.2, 0.3] assert parsed_result["top_k"] == 5 assert parsed_result["distance_metric"] == "cosine" - def test_construct_online_documents_v2_api_json_request(self, remote_store, feature_view): + def test_construct_online_documents_v2_api_json_request( + self, remote_store, feature_view + ): """Test _construct_online_documents_v2_api_json_request method.""" result = remote_store._construct_online_documents_v2_api_json_request( table=feature_view, @@ -323,9 +322,9 @@ def test_construct_online_documents_v2_api_json_request(self, remote_store, feat top_k=3, distance_metric="L2", query_string="test query", - api_version=2 + api_version=2, ) - + parsed_result = json.loads(result) assert parsed_result["features"] == ["test_feature_view:feature1"] assert parsed_result["query"] == [0.1, 0.2] @@ -334,19 +333,13 @@ def test_construct_online_documents_v2_api_json_request(self, remote_store, feat assert parsed_result["query_string"] == "test query" assert parsed_result["api_version"] == 2 - def test_extract_requested_feature_value(self, remote_store): """Test _extract_requested_feature_value helper method.""" response_json = { - "results": [ - { - "values": ["test_value"], - "statuses": ["PRESENT"] - } - ] + "results": [{"values": ["test_value"], "statuses": ["PRESENT"]}] } feature_name_to_index = {"feature1": 0} - + result = remote_store._extract_requested_feature_value( response_json, feature_name_to_index, ["feature1"], 0 ) @@ -354,13 +347,7 @@ def test_extract_requested_feature_value(self, remote_store): def test_is_feature_present(self, remote_store): """Test _is_feature_present helper method.""" - response_json = { - "results": [ - { - "statuses": ["PRESENT", "NOT_FOUND"] - } - ] - } - - assert remote_store._is_feature_present(response_json, 0, 0) == True - assert remote_store._is_feature_present(response_json, 0, 1) == False \ No newline at end of file + response_json = {"results": [{"statuses": ["PRESENT", "NOT_FOUND"]}]} + + assert remote_store._is_feature_present(response_json, 0, 0) + assert not remote_store._is_feature_present(response_json, 0, 1)