From cb4c64b2cc6a3da8c5adcb9c57149bfe798b5338 Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Sun, 30 Mar 2025 22:11:22 +0200 Subject: [PATCH 1/4] feat: Add support for hybrid search with vector and text queries in Milvus Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Signed-off-by: yassinnouh21 --- .../milvus_online_store/milvus.py | 265 ++++++++++++------ 1 file changed, 178 insertions(+), 87 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index 99405dc2ca0..38de3d32f52 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -95,6 +95,7 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): metric_type: Optional[str] = "COSINE" embedding_dim: Optional[int] = 128 vector_enabled: Optional[bool] = True + text_search_enabled: Optional[bool] = False nlist: Optional[int] = 128 username: Optional[StrictStr] = "" password: Optional[StrictStr] = "" @@ -113,8 +114,8 @@ class MilvusOnlineStore(OnlineStore): def _get_db_path(self, config: RepoConfig) -> str: assert ( - config.online_store.type == "milvus" - or config.online_store.type.endswith("MilvusOnlineStore") + config.online_store.type == "milvus" + or config.online_store.type.endswith("MilvusOnlineStore") ) if config.repo_path and not Path(config.online_store.path).is_absolute(): @@ -139,7 +140,7 @@ def _connect(self, config: RepoConfig) -> MilvusClient: return self.client def _get_or_create_collection( - self, config: RepoConfig, table: FeatureView + self, config: RepoConfig, table: FeatureView ) -> Dict[str, Any]: self.client = self._connect(config) vector_field_dict = {k.name: k for k in table.schema if k.vector_index} @@ -198,12 +199,12 @@ def _get_or_create_collection( index_params = self.client.prepare_index_params() for vector_field in schema.fields: if ( - vector_field.dtype - in [ - DataType.FLOAT_VECTOR, - DataType.BINARY_VECTOR, - ] - and vector_field.name in vector_field_dict + vector_field.dtype + in [ + DataType.FLOAT_VECTOR, + DataType.BINARY_VECTOR, + ] + and vector_field.name in vector_field_dict ): metric = vector_field_dict[ vector_field.name @@ -228,18 +229,18 @@ def _get_or_create_collection( return self._collections[collection_name] def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[ - EntityKeyProto, - Dict[str, ValueProto], - datetime, - Optional[datetime], - ] - ], - progress: Optional[Callable[[int], Any]], + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[ + EntityKeyProto, + Dict[str, ValueProto], + datetime, + Optional[datetime], + ] + ], + progress: Optional[Callable[[int], Any]], ) -> None: self.client = self._connect(config) collection = self._get_or_create_collection(config, table) @@ -286,8 +287,8 @@ def online_write_batch( single_entity_record[field] = "" # Store only the latest event timestamp per entity if ( - entity_key_str not in unique_entities - or unique_entities[entity_key_str]["event_ts"] < timestamp_int + entity_key_str not in unique_entities + or unique_entities[entity_key_str]["event_ts"] < timestamp_int ): unique_entities[entity_key_str] = single_entity_record @@ -301,12 +302,12 @@ def online_write_batch( ) def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - full_feature_names: bool = False, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + full_feature_names: bool = False, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: self.client = self._connect(config) collection_name = _table_id(config.project, table) @@ -315,9 +316,9 @@ def online_read( composite_key_name = _get_composite_key_name(table) output_fields = ( - [composite_key_name] - + (requested_features if requested_features else []) - + ["created_ts", "event_ts"] + [composite_key_name] + + (requested_features if requested_features else []) + + ["created_ts", "event_ts"] ) assert all( field in [f["name"] for f in collection["fields"]] @@ -334,9 +335,9 @@ def online_read( composite_entities.append(entity_key_str) query_filter_for_entities = ( - f"{composite_key_name} in [" - + ", ".join([f"'{e}'" for e in composite_entities]) - + "]" + f"{composite_key_name} in [" + + ", ".join([f"'{e}'" for e in composite_entities]) + + "]" ) self.client.load_collection(collection_name) results = self.client.query( @@ -440,13 +441,13 @@ def online_read( return result_list def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, ): self.client = self._connect(config) for table in tables_to_keep: @@ -459,15 +460,15 @@ def update( self._collections.pop(collection_name, None) def plan( - self, config: RepoConfig, desired_registry_proto: RegistryProto + self, config: RepoConfig, desired_registry_proto: RegistryProto ) -> List[InfraObject]: raise NotImplementedError def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): self.client = self._connect(config) for table in tables: @@ -477,14 +478,14 @@ def teardown( self._collections.pop(collection_name, None) def retrieve_online_documents_v2( - self, - config: RepoConfig, - table: FeatureView, - requested_features: List[str], - embedding: Optional[List[float]], - top_k: int, - distance_metric: Optional[str] = None, - query_string: Optional[str] = None, + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -492,27 +493,40 @@ def retrieve_online_documents_v2( Optional[Dict[str, ValueProto]], ] ]: - assert embedding is not None, "Key Word Search not yet implemented for Milvus" + """ + Retrieve documents using vector similarity search or keyword search in Milvus. + + Args: + config: Feast configuration object + table: FeatureView object as the table to search + requested_features: List of requested features to retrieve + embedding: Query embedding to search for (optional) + top_k: Number of items to return + distance_metric: Distance metric to use (optional) + query_string: The query string to search for using keyword search (optional) + + Returns: + List of tuples containing the event timestamp, entity key, and feature values + """ entity_name_feast_primitive_type_map = { k.name: k.dtype for k in table.entity_columns } self.client = self._connect(config) collection_name = _table_id(config.project, table) collection = self._get_or_create_collection(config, table) + if not config.online_store.vector_enabled: raise ValueError("Vector search is not enabled in the online store config") - search_params = { - "metric_type": distance_metric or config.online_store.metric_type, - "params": {"nprobe": 10}, - } + if embedding is None and query_string is None: + raise ValueError("Either embedding or query_string must be provided") composite_key_name = _get_composite_key_name(table) output_fields = ( - [composite_key_name] - + (requested_features if requested_features else []) - + ["created_ts", "event_ts"] + [composite_key_name] + + (requested_features if requested_features else []) + + ["created_ts", "event_ts"] ) assert all( field in [f["name"] for f in collection["fields"]] @@ -520,25 +534,102 @@ def retrieve_online_documents_v2( ), ( f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema" ) - # Note we choose the first vector field as the field to search on. Not ideal but it's something. + + # Find the vector search field if we need it ann_search_field = None - for field in collection["fields"]: - if ( - field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] - and field["name"] in output_fields - ): - ann_search_field = field["name"] - break + if embedding is not None: + for field in collection["fields"]: + if ( + field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] + and field["name"] in output_fields + ): + ann_search_field = field["name"] + break self.client.load_collection(collection_name) - results = self.client.search( - collection_name=collection_name, - data=[embedding], - anns_field=ann_search_field, - search_params=search_params, - limit=top_k, - output_fields=output_fields, - ) + + if embedding is not None and query_string is not None and config.online_store.vector_enabled: + string_field_list = [ + f.name for f in table.features if + isinstance(f.dtype, PrimitiveFeastType) and f.dtype.to_value_type() == ValueType.STRING + ] + + if not string_field_list: + raise ValueError("No string fields found in the feature view for text search in hybrid mode") + + # Create a filter expression for text search + filter_expressions = [] + for field in string_field_list: + if field in output_fields: + filter_expressions.append(f"{field} LIKE '%{query_string}%'") + + # Combine filter expressions with OR + filter_expr = " OR ".join(filter_expressions) if filter_expressions else "" + + # Vector search with text filter + search_params = { + "metric_type": distance_metric or config.online_store.metric_type, + "params": {"nprobe": 10}, + } + + # For hybrid search, use filter parameter instead of expr + results = self.client.search( + collection_name=collection_name, + data=[embedding], + anns_field=ann_search_field, + search_params=search_params, + limit=top_k, + output_fields=output_fields, + filter=filter_expr if filter_expr else None, + ) + + elif embedding is not None and config.online_store.vector_enabled: + # Vector search only + search_params = { + "metric_type": distance_metric or config.online_store.metric_type, + "params": {"nprobe": 10}, + } + + results = self.client.search( + collection_name=collection_name, + data=[embedding], + anns_field=ann_search_field, + search_params=search_params, + limit=top_k, + output_fields=output_fields, + ) + + elif query_string is not None: + string_field_list = [ + f.name for f in table.features if + isinstance(f.dtype, PrimitiveFeastType) and f.dtype.to_value_type() == ValueType.STRING + ] + + if not string_field_list: + raise ValueError("No string fields found in the feature view for text search") + + filter_expressions = [] + for field in string_field_list: + if field in output_fields: + filter_expressions.append(f"{field} LIKE '%{query_string}%'") + + filter_expr = " OR ".join(filter_expressions) + + if not filter_expr: + raise ValueError("No text fields found in requested features for search") + + query_results = self.client.query( + collection_name=collection_name, + filter=filter_expr, + output_fields=output_fields, + limit=top_k, + ) + + results = [[{"entity": entity, "distance": -1.0}] for entity in query_results] + else: + raise ValueError( + "Either vector_enabled must be True for embedding search or query_string must be provided for keyword search" + ) result_list = [] for hits in results: @@ -559,20 +650,20 @@ def retrieve_online_documents_v2( # entity_key_proto = None if field in ["created_ts", "event_ts"]: res_ts = datetime.fromtimestamp(field_value / 1e6) - elif field == ann_search_field: + elif field == ann_search_field and embedding is not None: serialized_embedding = _serialize_vector_to_float_list( embedding ) res[ann_search_field] = serialized_embedding elif entity_name_feast_primitive_type_map.get( - field, PrimitiveFeastType.INVALID + field, PrimitiveFeastType.INVALID ) in [ PrimitiveFeastType.STRING, PrimitiveFeastType.BYTES, ]: res[field] = ValueProto(string_val=str(field_value)) elif entity_name_feast_primitive_type_map.get( - field, PrimitiveFeastType.INVALID + field, PrimitiveFeastType.INVALID ) in [ PrimitiveFeastType.INT64, PrimitiveFeastType.INT32, @@ -603,9 +694,9 @@ def _get_composite_key_name(table: FeatureView) -> str: def _extract_proto_values_to_dict( - input_dict: Dict[str, Any], - vector_cols: List[str], - serialize_to_string=False, + input_dict: Dict[str, Any], + vector_cols: List[str], + serialize_to_string=False, ) -> Dict[str, Any]: numeric_vector_list_types = [ k @@ -633,8 +724,8 @@ def _extract_proto_values_to_dict( vector_values = getattr(feature_values, proto_val_type).val else: if ( - serialize_to_string - and proto_val_type not in ["string_val"] + numeric_types + serialize_to_string + and proto_val_type not in ["string_val"] + numeric_types ): vector_values = feature_values.SerializeToString().decode() else: From f4420b074e0b1365cdb09d1c739904628f3f0789 Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Sun, 30 Mar 2025 22:13:57 +0200 Subject: [PATCH 2/4] test: Add keyword and hybrid search tests for Milvus online store Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Signed-off-by: yassinnouh21 --- .../online_store/test_online_retrieval.py | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index e7fca47bb55..eb232518acf 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -1484,3 +1484,176 @@ def test_milvus_native_from_feast_data() -> None: # Clean up the collection client.drop_collection(collection_name=COLLECTION_NAME) + + +def test_milvus_keyword_search() -> None: + """ + Test retrieving documents from the Milvus online store using keyword search. + """ + random.seed(42) + n = 10 # number of samples + vector_length = 10 + runner = CliRunner() + with runner.local_repo( + example_repo_py=get_example_repo("example_rag_feature_repo.py"), + offline_store="file", + online_store="milvus", + apply=False, + teardown=False, + ) as store: + from datetime import timedelta + + from feast import Entity, FeatureView, Field, FileSource + from feast.types import Array, Float32, Int64, String, UnixTimestamp + + rag_documents_source = FileSource( + path="data/embedded_documents.parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created_timestamp", + ) + + item = Entity( + name="item_id", + join_keys=["item_id"], + value_type=ValueType.INT64, + ) + author = Entity( + name="author_id", + join_keys=["author_id"], + value_type=ValueType.STRING, + ) + + document_embeddings = FeatureView( + name="text_documents", + entities=[item, author], + schema=[ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_search_metric="COSINE", + ), + Field(name="item_id", dtype=Int64), + Field(name="author_id", dtype=String), + Field(name="content", dtype=String), + Field(name="title", dtype=String), + Field(name="created_timestamp", dtype=UnixTimestamp), + Field(name="event_timestamp", dtype=UnixTimestamp), + ], + source=rag_documents_source, + ttl=timedelta(hours=24), + ) + + store.apply([rag_documents_source, item, document_embeddings]) + + # Write some data with specific text content for keyword search + document_embeddings_fv = store.get_feature_view(name="text_documents") + provider = store._get_provider() + + contents = [ + "Feast is an open source feature store for machine learning", + "Feature stores solve the problem of coordinating features for training and serving", + "Milvus is a vector database that can be used with Feast", + "Keyword search uses BM25 algorithm for relevance ranking", + "Vector search uses embeddings for semantic similarity", + "Python is a popular programming language for machine learning", + "Feast supports multiple storage backends for online and offline use cases", + "Online stores are used for low-latency feature serving", + "Offline stores are used for batch feature retrieval during training", + "Feast enables data scientists to define, manage, and share features", + ] + + titles = [ + "Introduction to Feast", + "Feature Store Benefits", + "Using Milvus with Feast", + "Keyword Search Fundamentals", + "Vector Search Overview", + "Python for ML", + "Feast Storage Options", + "Online Serving with Feast", + "Offline Training Support", + "Feast for Data Scientists", + ] + + item_keys = [ + EntityKeyProto( + join_keys=["item_id", "author_id"], + entity_values=[ + ValueProto(int64_val=i), + ValueProto(string_val=f"author_{i}"), + ], + ) + for i in range(n) + ] + data = [] + for i, item_key in enumerate(item_keys): + data.append( + ( + item_key, + { + "vector": ValueProto( + float_list_val=FloatListProto( + val=np.random.random(vector_length) + ) + ), + "content": ValueProto(string_val=contents[i]), + "title": ValueProto(string_val=titles[i]), + }, + _utc_now(), + _utc_now(), + ) + ) + + provider.online_write_batch( + config=store.config, + table=document_embeddings_fv, + data=data, + progress=None, + ) + + # Test keyword search for "Milvus" + result_milvus = store.retrieve_online_documents_v2( + features=[ + "text_documents:content", + "text_documents:title", + ], + query_string="Milvus", + top_k=3, + ).to_dict() + + # Verify that documents containing "Milvus" are returned + assert len(result_milvus["content"]) > 0 + assert any("Milvus" in content for content in result_milvus["content"]) + + # Test keyword search for "machine learning" + result_ml = store.retrieve_online_documents_v2( + features=[ + "text_documents:content", + "text_documents:title", + ], + query_string="machine learning", + top_k=3, + ).to_dict() + + # Verify that documents containing "machine learning" are returned + assert len(result_ml["content"]) > 0 + assert any("machine learning" in content.lower() for content in result_ml["content"]) + + # Test hybrid search (vector + keyword) + query_embedding = np.random.random(vector_length).tolist() + result_hybrid = store.retrieve_online_documents_v2( + features=[ + "text_documents:content", + "text_documents:title", + "text_documents:vector", + ], + query=query_embedding, + query_string="Feast", + top_k=3, + ).to_dict() + + # Verify hybrid search results + assert len(result_hybrid["content"]) > 0 + assert any("Feast" in content for content in result_hybrid["content"]) + assert len(result_hybrid["vector"]) > 0 From e2f3e28443b79ab04c971b01075994da9dde166f Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Mon, 31 Mar 2025 01:05:23 +0200 Subject: [PATCH 3/4] fix linter Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Signed-off-by: yassinnouh21 --- .../milvus_online_store/milvus.py | 131 +++++++++--------- 1 file changed, 65 insertions(+), 66 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index 38de3d32f52..9b2456a7443 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -114,8 +114,8 @@ class MilvusOnlineStore(OnlineStore): def _get_db_path(self, config: RepoConfig) -> str: assert ( - config.online_store.type == "milvus" - or config.online_store.type.endswith("MilvusOnlineStore") + config.online_store.type == "milvus" + or config.online_store.type.endswith("MilvusOnlineStore") ) if config.repo_path and not Path(config.online_store.path).is_absolute(): @@ -140,7 +140,7 @@ def _connect(self, config: RepoConfig) -> MilvusClient: return self.client def _get_or_create_collection( - self, config: RepoConfig, table: FeatureView + self, config: RepoConfig, table: FeatureView ) -> Dict[str, Any]: self.client = self._connect(config) vector_field_dict = {k.name: k for k in table.schema if k.vector_index} @@ -199,12 +199,12 @@ def _get_or_create_collection( index_params = self.client.prepare_index_params() for vector_field in schema.fields: if ( - vector_field.dtype - in [ - DataType.FLOAT_VECTOR, - DataType.BINARY_VECTOR, - ] - and vector_field.name in vector_field_dict + vector_field.dtype + in [ + DataType.FLOAT_VECTOR, + DataType.BINARY_VECTOR, + ] + and vector_field.name in vector_field_dict ): metric = vector_field_dict[ vector_field.name @@ -229,18 +229,18 @@ def _get_or_create_collection( return self._collections[collection_name] def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[ - EntityKeyProto, - Dict[str, ValueProto], - datetime, - Optional[datetime], - ] - ], - progress: Optional[Callable[[int], Any]], + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[ + EntityKeyProto, + Dict[str, ValueProto], + datetime, + Optional[datetime], + ] + ], + progress: Optional[Callable[[int], Any]], ) -> None: self.client = self._connect(config) collection = self._get_or_create_collection(config, table) @@ -287,8 +287,8 @@ def online_write_batch( single_entity_record[field] = "" # Store only the latest event timestamp per entity if ( - entity_key_str not in unique_entities - or unique_entities[entity_key_str]["event_ts"] < timestamp_int + entity_key_str not in unique_entities + or unique_entities[entity_key_str]["event_ts"] < timestamp_int ): unique_entities[entity_key_str] = single_entity_record @@ -302,12 +302,12 @@ def online_write_batch( ) def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - full_feature_names: bool = False, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + full_feature_names: bool = False, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: self.client = self._connect(config) collection_name = _table_id(config.project, table) @@ -316,9 +316,9 @@ def online_read( composite_key_name = _get_composite_key_name(table) output_fields = ( - [composite_key_name] - + (requested_features if requested_features else []) - + ["created_ts", "event_ts"] + [composite_key_name] + + (requested_features if requested_features else []) + + ["created_ts", "event_ts"] ) assert all( field in [f["name"] for f in collection["fields"]] @@ -335,9 +335,9 @@ def online_read( composite_entities.append(entity_key_str) query_filter_for_entities = ( - f"{composite_key_name} in [" - + ", ".join([f"'{e}'" for e in composite_entities]) - + "]" + f"{composite_key_name} in [" + + ", ".join([f"'{e}'" for e in composite_entities]) + + "]" ) self.client.load_collection(collection_name) results = self.client.query( @@ -441,13 +441,13 @@ def online_read( return result_list def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, ): self.client = self._connect(config) for table in tables_to_keep: @@ -460,15 +460,15 @@ def update( self._collections.pop(collection_name, None) def plan( - self, config: RepoConfig, desired_registry_proto: RegistryProto + self, config: RepoConfig, desired_registry_proto: RegistryProto ) -> List[InfraObject]: raise NotImplementedError def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): self.client = self._connect(config) for table in tables: @@ -478,14 +478,14 @@ def teardown( self._collections.pop(collection_name, None) def retrieve_online_documents_v2( - self, - config: RepoConfig, - table: FeatureView, - requested_features: List[str], - embedding: Optional[List[float]], - top_k: int, - distance_metric: Optional[str] = None, - query_string: Optional[str] = None, + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], @@ -514,7 +514,6 @@ def retrieve_online_documents_v2( self.client = self._connect(config) collection_name = _table_id(config.project, table) collection = self._get_or_create_collection(config, table) - if not config.online_store.vector_enabled: raise ValueError("Vector search is not enabled in the online store config") @@ -524,9 +523,9 @@ def retrieve_online_documents_v2( composite_key_name = _get_composite_key_name(table) output_fields = ( - [composite_key_name] - + (requested_features if requested_features else []) - + ["created_ts", "event_ts"] + [composite_key_name] + + (requested_features if requested_features else []) + + ["created_ts", "event_ts"] ) assert all( field in [f["name"] for f in collection["fields"]] @@ -656,14 +655,14 @@ def retrieve_online_documents_v2( ) res[ann_search_field] = serialized_embedding elif entity_name_feast_primitive_type_map.get( - field, PrimitiveFeastType.INVALID + field, PrimitiveFeastType.INVALID ) in [ PrimitiveFeastType.STRING, PrimitiveFeastType.BYTES, ]: res[field] = ValueProto(string_val=str(field_value)) elif entity_name_feast_primitive_type_map.get( - field, PrimitiveFeastType.INVALID + field, PrimitiveFeastType.INVALID ) in [ PrimitiveFeastType.INT64, PrimitiveFeastType.INT32, @@ -694,9 +693,9 @@ def _get_composite_key_name(table: FeatureView) -> str: def _extract_proto_values_to_dict( - input_dict: Dict[str, Any], - vector_cols: List[str], - serialize_to_string=False, + input_dict: Dict[str, Any], + vector_cols: List[str], + serialize_to_string=False, ) -> Dict[str, Any]: numeric_vector_list_types = [ k @@ -724,8 +723,8 @@ def _extract_proto_values_to_dict( vector_values = getattr(feature_values, proto_val_type).val else: if ( - serialize_to_string - and proto_val_type not in ["string_val"] + numeric_types + serialize_to_string + and proto_val_type not in ["string_val"] + numeric_types ): vector_values = feature_values.SerializeToString().decode() else: From 79548cc7cd76675fd50d1f11d71a3fb913dda01e Mon Sep 17 00:00:00 2001 From: yassinnouh21 Date: Mon, 31 Mar 2025 04:44:05 +0200 Subject: [PATCH 4/4] fix linter 2 Signed-off-by: yassinnouh21 --- .../milvus_online_store/milvus.py | 40 +++++++++++++------ .../online_store/test_online_retrieval.py | 36 +++++++++-------- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index 9b2456a7443..d39a2e3a16c 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -495,7 +495,6 @@ def retrieve_online_documents_v2( ]: """ Retrieve documents using vector similarity search or keyword search in Milvus. - Args: config: Feast configuration object table: FeatureView object as the table to search @@ -504,7 +503,6 @@ def retrieve_online_documents_v2( top_k: Number of items to return distance_metric: Distance metric to use (optional) query_string: The query string to search for using keyword search (optional) - Returns: List of tuples containing the event timestamp, entity key, and feature values """ @@ -539,22 +537,30 @@ def retrieve_online_documents_v2( if embedding is not None: for field in collection["fields"]: if ( - field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] - and field["name"] in output_fields + field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] + and field["name"] in output_fields ): ann_search_field = field["name"] break self.client.load_collection(collection_name) - if embedding is not None and query_string is not None and config.online_store.vector_enabled: + if ( + embedding is not None + and query_string is not None + and config.online_store.vector_enabled + ): string_field_list = [ - f.name for f in table.features if - isinstance(f.dtype, PrimitiveFeastType) and f.dtype.to_value_type() == ValueType.STRING + f.name + for f in table.features + if isinstance(f.dtype, PrimitiveFeastType) + and f.dtype.to_value_type() == ValueType.STRING ] if not string_field_list: - raise ValueError("No string fields found in the feature view for text search in hybrid mode") + raise ValueError( + "No string fields found in the feature view for text search in hybrid mode" + ) # Create a filter expression for text search filter_expressions = [] @@ -600,12 +606,16 @@ def retrieve_online_documents_v2( elif query_string is not None: string_field_list = [ - f.name for f in table.features if - isinstance(f.dtype, PrimitiveFeastType) and f.dtype.to_value_type() == ValueType.STRING + f.name + for f in table.features + if isinstance(f.dtype, PrimitiveFeastType) + and f.dtype.to_value_type() == ValueType.STRING ] if not string_field_list: - raise ValueError("No string fields found in the feature view for text search") + raise ValueError( + "No string fields found in the feature view for text search" + ) filter_expressions = [] for field in string_field_list: @@ -615,7 +625,9 @@ def retrieve_online_documents_v2( filter_expr = " OR ".join(filter_expressions) if not filter_expr: - raise ValueError("No text fields found in requested features for search") + raise ValueError( + "No text fields found in requested features for search" + ) query_results = self.client.query( collection_name=collection_name, @@ -624,7 +636,9 @@ def retrieve_online_documents_v2( limit=top_k, ) - results = [[{"entity": entity, "distance": -1.0}] for entity in query_results] + results = [ + [{"entity": entity, "distance": -1.0}] for entity in query_results + ] else: raise ValueError( "Either vector_enabled must be True for embedding search or query_string must be provided for keyword search" diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index eb232518acf..409a729ceee 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -1502,16 +1502,16 @@ def test_milvus_keyword_search() -> None: teardown=False, ) as store: from datetime import timedelta - + from feast import Entity, FeatureView, Field, FileSource from feast.types import Array, Float32, Int64, String, UnixTimestamp - + rag_documents_source = FileSource( path="data/embedded_documents.parquet", timestamp_field="event_timestamp", created_timestamp_column="created_timestamp", ) - + item = Entity( name="item_id", join_keys=["item_id"], @@ -1522,7 +1522,7 @@ def test_milvus_keyword_search() -> None: join_keys=["author_id"], value_type=ValueType.STRING, ) - + document_embeddings = FeatureView( name="text_documents", entities=[item, author], @@ -1543,13 +1543,13 @@ def test_milvus_keyword_search() -> None: source=rag_documents_source, ttl=timedelta(hours=24), ) - + store.apply([rag_documents_source, item, document_embeddings]) - + # Write some data with specific text content for keyword search document_embeddings_fv = store.get_feature_view(name="text_documents") provider = store._get_provider() - + contents = [ "Feast is an open source feature store for machine learning", "Feature stores solve the problem of coordinating features for training and serving", @@ -1562,7 +1562,7 @@ def test_milvus_keyword_search() -> None: "Offline stores are used for batch feature retrieval during training", "Feast enables data scientists to define, manage, and share features", ] - + titles = [ "Introduction to Feast", "Feature Store Benefits", @@ -1575,7 +1575,7 @@ def test_milvus_keyword_search() -> None: "Offline Training Support", "Feast for Data Scientists", ] - + item_keys = [ EntityKeyProto( join_keys=["item_id", "author_id"], @@ -1604,14 +1604,14 @@ def test_milvus_keyword_search() -> None: _utc_now(), ) ) - + provider.online_write_batch( config=store.config, table=document_embeddings_fv, data=data, progress=None, ) - + # Test keyword search for "Milvus" result_milvus = store.retrieve_online_documents_v2( features=[ @@ -1621,11 +1621,11 @@ def test_milvus_keyword_search() -> None: query_string="Milvus", top_k=3, ).to_dict() - + # Verify that documents containing "Milvus" are returned assert len(result_milvus["content"]) > 0 assert any("Milvus" in content for content in result_milvus["content"]) - + # Test keyword search for "machine learning" result_ml = store.retrieve_online_documents_v2( features=[ @@ -1635,11 +1635,13 @@ def test_milvus_keyword_search() -> None: query_string="machine learning", top_k=3, ).to_dict() - + # Verify that documents containing "machine learning" are returned assert len(result_ml["content"]) > 0 - assert any("machine learning" in content.lower() for content in result_ml["content"]) - + assert any( + "machine learning" in content.lower() for content in result_ml["content"] + ) + # Test hybrid search (vector + keyword) query_embedding = np.random.random(vector_length).tolist() result_hybrid = store.retrieve_online_documents_v2( @@ -1652,7 +1654,7 @@ def test_milvus_keyword_search() -> None: query_string="Feast", top_k=3, ).to_dict() - + # Verify hybrid search results assert len(result_hybrid["content"]) > 0 assert any("Feast" in content for content in result_hybrid["content"])