diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 1578a91574e..bd6f1873874 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -2097,15 +2097,34 @@ def _retrieve_from_online_store_v2( entity_key_dict[key] = [] entity_key_dict[key].append(python_value) - table_entity_values, idxs, output_len = utils._get_unique_entities_from_values( - entity_key_dict, - ) - features_to_request: List[str] = [] if requested_features: features_to_request = requested_features + ["distance"] + # Add text_rank for text search queries + if query_string is not None: + features_to_request.append("text_rank") else: features_to_request = ["distance"] + # Add text_rank for text search queries + if query_string is not None: + features_to_request.append("text_rank") + + if not datevals: + online_features_response = GetOnlineFeaturesResponse(results=[]) + for feature in features_to_request: + field = online_features_response.results.add() + field.values.extend([]) + field.statuses.extend([]) + field.event_timestamps.extend([]) + online_features_response.metadata.feature_names.val.extend( + features_to_request + ) + return OnlineResponse(online_features_response) + + table_entity_values, idxs, output_len = utils._get_unique_entities_from_values( + entity_key_dict, + ) + feature_data = utils._convert_rows_to_protobuf( requested_features=features_to_request, read_rows=list(zip(datevals, list_of_feature_dicts)), diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index b5c1dd05f3a..daa128264df 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -20,8 +20,7 @@ from psycopg.connection import Connection from psycopg_pool import AsyncConnectionPool, ConnectionPool -from feast import Entity -from feast.feature_view import FeatureView +from feast import Entity, FeatureView, ValueType from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key from feast.infra.online_stores.helpers import _to_naive_utc from feast.infra.online_stores.online_store import OnlineStore @@ -119,6 +118,12 @@ def online_write_batch( for feature_name, val in values.items(): vector_val = None + value_text = None + + # Check if the feature type is STRING + if val.WhichOneof("val") == "string_val": + value_text = val.string_val + if config.online_store.vector_enabled: vector_val = get_list_val_str(val) insert_values.append( @@ -126,6 +131,7 @@ def online_write_batch( entity_key_bin, feature_name, val.SerializeToString(), + value_text, vector_val, timestamp, created_ts, @@ -136,11 +142,12 @@ def online_write_batch( sql_query = sql.SQL( """ INSERT INTO {} - (entity_key, feature_name, value, vector_value, event_ts, created_ts) - VALUES (%s, %s, %s, %s, %s, %s) + (entity_key, feature_name, value, value_text, vector_value, event_ts, created_ts) + VALUES (%s, %s, %s, %s, %s, %s, %s) ON CONFLICT (entity_key, feature_name) DO UPDATE SET value = EXCLUDED.value, + value_text = EXCLUDED.value_text, vector_value = EXCLUDED.vector_value, event_ts = EXCLUDED.event_ts, created_ts = EXCLUDED.created_ts; @@ -308,6 +315,11 @@ def update( else: # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility vector_value_type = "BYTEA" + + has_string_features = any( + f.dtype.to_value_type() == ValueType.STRING for f in table.features + ) + cur.execute( sql.SQL( """ @@ -316,6 +328,7 @@ def update( entity_key BYTEA, feature_name TEXT, value BYTEA, + value_text TEXT NULL, -- Added for FTS vector_value {} NULL, event_ts TIMESTAMPTZ, created_ts TIMESTAMPTZ, @@ -331,6 +344,16 @@ def update( ) ) + if has_string_features: + cur.execute( + sql.SQL( + """CREATE INDEX IF NOT EXISTS {} ON {} USING GIN (to_tsvector('english', value_text));""" + ).format( + sql.Identifier(f"{table_name}_fts_idx"), + sql.Identifier(table_name), + ) + ) + conn.commit() def teardown( @@ -456,6 +479,254 @@ def retrieve_online_documents( return result + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Retrieve documents using vector similarity search or keyword search in PostgreSQL. + + Args: + config: Feast configuration object + table: FeatureView object as the table to search + requested_features: List of requested features to retrieve + embedding: Query embedding to search for (optional) + top_k: Number of items to return + distance_metric: Distance metric to use (optional) + query_string: The query string to search for using keyword search (optional) + + Returns: + List of tuples containing the event timestamp, entity key, and feature values + """ + if not config.online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config") + + if embedding is None and query_string is None: + raise ValueError("Either embedding or query_string must be provided") + + distance_metric = distance_metric or "L2" + + if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT: + raise ValueError( + f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}" + ) + + distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric] + + string_fields = [ + feature.name + for feature in table.features + if feature.dtype.to_value_type().value == 2 + and feature.name in requested_features + ] + + table_name = _table_id(config.project, table) + + with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur: + query = None + params: Any = None + + if embedding is not None and query_string is not None and string_fields: + # Case 1: Hybrid Search (vector + text) + tsquery_str = " & ".join(query_string.split()) + query = sql.SQL( + """ + SELECT + entity_key, + feature_name, + value, + vector_value, + vector_value {distance_metric_sql} %s::vector as distance, + ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank, + event_ts, + created_ts + FROM {table_name} + WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) + ORDER BY distance + LIMIT {top_k} + """ + ).format( + distance_metric_sql=sql.SQL(distance_metric_sql), + table_name=sql.Identifier(table_name), + top_k=sql.Literal(top_k), + ) + params = (embedding, tsquery_str, string_fields, tsquery_str) + + elif embedding is not None: + # Case 2: Vector Search Only + query = sql.SQL( + """ + SELECT + entity_key, + feature_name, + value, + vector_value, + vector_value {distance_metric_sql} %s::vector as distance, + NULL as text_rank, -- Keep consistent columns + event_ts, + created_ts + FROM {table_name} + ORDER BY distance + LIMIT {top_k} + """ + ).format( + distance_metric_sql=sql.SQL(distance_metric_sql), + table_name=sql.Identifier(table_name), + top_k=sql.Literal(top_k), + ) + params = (embedding,) + + elif query_string is not None and string_fields: + # Case 3: Text Search Only + tsquery_str = " & ".join(query_string.split()) + query = sql.SQL( + """ + WITH text_matches AS ( + SELECT DISTINCT entity_key, ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank + FROM {table_name} + WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) + ORDER BY text_rank DESC + LIMIT {top_k} + ) + SELECT + t1.entity_key, + t1.feature_name, + t1.value, + t1.vector_value, + NULL as distance, + t2.text_rank, + t1.event_ts, + t1.created_ts + FROM {table_name} t1 + INNER JOIN text_matches t2 ON t1.entity_key = t2.entity_key + WHERE t1.feature_name = ANY(%s) + ORDER BY t2.text_rank DESC + """ + ).format( + table_name=sql.Identifier(table_name), + top_k=sql.Literal(top_k), + ) + params = (tsquery_str, string_fields, tsquery_str, requested_features) + + else: + raise ValueError( + "Either vector_enabled must be True for embedding search or string fields must be available for query_string search" + ) + + cur.execute(query, params) + rows = cur.fetchall() + + # Group by entity_key to build feature records + entities_dict: Dict[str, Dict[str, Any]] = defaultdict( + lambda: { + "features": {}, + "timestamp": None, + "entity_key_proto": None, + "vector_distance": float("inf"), + "text_rank": 0.0, + } + ) + + for ( + entity_key_bytes, + feature_name, + feature_val_bytes, + vector_val, + distance, + text_rank, + event_ts, + created_ts, + ) in rows: + entity_key_proto = None + if entity_key_bytes: + from feast.infra.key_encoding_utils import deserialize_entity_key + + entity_key_proto = deserialize_entity_key(entity_key_bytes) + + key = entity_key_bytes.hex() if entity_key_bytes else None + + if key is None: + continue + + entities_dict[key]["entity_key_proto"] = entity_key_proto + + if ( + entities_dict[key]["timestamp"] is None + or event_ts > entities_dict[key]["timestamp"] + ): + entities_dict[key]["timestamp"] = event_ts + + val = ValueProto() + if feature_val_bytes: + val.ParseFromString(feature_val_bytes) + + entities_dict[key]["features"][feature_name] = val + + if distance is not None: + entities_dict[key]["vector_distance"] = min( + entities_dict[key]["vector_distance"], float(distance) + ) + if text_rank is not None: + entities_dict[key]["text_rank"] = max( + entities_dict[key]["text_rank"], float(text_rank) + ) + + sorted_entities = sorted( + entities_dict.values(), + key=lambda x: x["vector_distance"] + if embedding is not None + else x["text_rank"], + reverse=(embedding is None), + )[:top_k] + + result: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + for entity_data in sorted_entities: + features = ( + entity_data["features"].copy() + if isinstance(entity_data["features"], dict) + else None + ) + + if features is not None: + if "vector_distance" in entity_data and entity_data[ + "vector_distance" + ] != float("inf"): + dist_val = ValueProto() + dist_val.double_val = entity_data["vector_distance"] + features["distance"] = dist_val + + if embedding is None or query_string is not None: + rank_val = ValueProto() + rank_val.double_val = entity_data["text_rank"] + features["text_rank"] = rank_val + + result.append( + ( + entity_data["timestamp"], + entity_data["entity_key_proto"], + features, + ) + ) + return result + def _table_id(project: str, table: FeatureView) -> str: return f"{project}_{table.name}" diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index b563f00bfd1..523cf700d46 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -1,8 +1,8 @@ -import datetime import os +import random import time import unittest -from datetime import timedelta +from datetime import datetime, timedelta from typing import Any, Dict, List, Tuple, Union import assertpy @@ -18,9 +18,10 @@ from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.field import Field +from feast.infra.offline_stores.file_source import FileSource from feast.infra.utils.postgres.postgres_config import ConnectionType from feast.online_response import TIMESTAMP_POSTFIX -from feast.types import Float32, Int32, String +from feast.types import Array, Float32, Int32, Int64, String, ValueType from feast.utils import _utc_now from feast.wait import wait_retry_backoff from tests.integration.feature_repos.repo_configuration import ( @@ -219,7 +220,7 @@ def test_write_to_online_store_event_check(environment): # writes to online store via datasource (dataframe_source) materialization fs.materialize( - start_date=datetime.datetime.now() - timedelta(hours=12), + start_date=datetime.now() - timedelta(hours=12), end_date=_utc_now(), ) @@ -861,7 +862,7 @@ def assert_feature_service_entity_mapping_correctness( @pytest.mark.integration -@pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch", "qdrant"]) +@pytest.mark.universal_online_stores(only=["pgvector"]) def test_retrieve_online_documents(environment, fake_document_data): fs = environment.feature_store df, data_source = fake_document_data @@ -919,3 +920,171 @@ def test_retrieve_online_milvus_documents(environment, fake_document_data): assert len(documents["item_id"]) == 2 assert documents["item_id"] == [2, 3] + + +@pytest.mark.integration +@pytest.mark.universal_online_stores(only=["pgvector"]) +def test_postgres_retrieve_online_documents_v2(environment, fake_document_data): + """Test retrieval of documents using PostgreSQL vector store capabilities.""" + fs = environment.feature_store + + n_rows = 20 + vector_dim = 2 + random.seed(42) + + df = pd.DataFrame( + { + "item_id": list(range(n_rows)), + "embedding": [list(np.random.random(vector_dim)) for _ in range(n_rows)], + "text_field": [ + f"Document text content {i} with searchable keywords" + for i in range(n_rows) + ], + "category": [f"Category-{i % 5}" for i in range(n_rows)], + "event_timestamp": [datetime.now() for _ in range(n_rows)], + } + ) + + data_source = FileSource( + path="dummy_path.parquet", timestamp_field="event_timestamp" + ) + + item = Entity( + name="item_id", + join_keys=["item_id"], + value_type=ValueType.INT64, + ) + + item_embeddings_fv = FeatureView( + name="item_embeddings", + entities=[item], + schema=[ + Field(name="embedding", dtype=Array(Float32), vector_index=True), + Field(name="text_field", dtype=String), + Field(name="category", dtype=String), + Field(name="item_id", dtype=Int64), + ], + source=data_source, + ) + + fs.apply([item_embeddings_fv, item]) + fs.write_to_online_store("item_embeddings", df) + + # Test 1: Vector similarity search + query_embedding = list(np.random.random(vector_dim)) + vector_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + top_k=5, + distance_metric="L2", + ).to_dict() + + assert len(vector_results["embedding"]) == 5 + assert len(vector_results["distance"]) == 5 + assert len(vector_results["text_field"]) == 5 + assert len(vector_results["category"]) == 5 + + # Test 2: Vector similarity search with Cosine distance + vector_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + top_k=5, + distance_metric="cosine", + ).to_dict() + + assert len(vector_results["embedding"]) == 5 + assert len(vector_results["distance"]) == 5 + assert len(vector_results["text_field"]) == 5 + assert len(vector_results["category"]) == 5 + + # Test 3: Full text search + text_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query_string="searchable keywords", + top_k=5, + ).to_dict() + + # Verify text search results + assert len(text_results["text_field"]) == 5 + assert len(text_results["text_rank"]) == 5 + assert len(text_results["category"]) == 5 + assert len(text_results["item_id"]) == 5 + + # Verify text rank values are between 0 and 1 + assert all(0 <= rank <= 1 for rank in text_results["text_rank"]) + + # Verify results are sorted by text rank in descending order + text_ranks = text_results["text_rank"] + assert all(text_ranks[i] >= text_ranks[i + 1] for i in range(len(text_ranks) - 1)) + + # Test 4: Hybrid search (vector + text) + hybrid_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + query_string="searchable keywords", + top_k=5, + distance_metric="L2", + ).to_dict() + + # Verify hybrid search results + assert len(hybrid_results["embedding"]) == 5 + assert len(hybrid_results["distance"]) == 5 + assert len(hybrid_results["text_field"]) == 5 + assert len(hybrid_results["text_rank"]) == 5 + assert len(hybrid_results["category"]) == 5 + assert len(hybrid_results["item_id"]) == 5 + + # Test 5: Hybrid search with different text query + hybrid_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + query_string="Category-1", + top_k=5, + distance_metric="L2", + ).to_dict() + + # Verify results contain only documents from Category-1 + assert all(cat == "Category-1" for cat in hybrid_results["category"]) + + # Test 6: Full text search with no matches + no_match_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query_string="nonexistent keyword", + top_k=5, + ).to_dict() + + # Verify no results are returned for non-matching query + assert "text_field" in no_match_results + assert len(no_match_results["text_field"]) == 0 + assert "text_rank" in no_match_results + assert len(no_match_results["text_rank"]) == 0