Skip to content

Commit c2af9a7

Browse files
YassinNouh21jfw-ppi
authored andcommitted
feat: Add retrieve online documents v2 method into pgvector (feast-dev#5253)
* feat: add online document retrieval with hybrid search capabilities Signed-off-by: yassinnouh21 <yassinnouh21@gmail.com> * test: add integration tests for hybrid search and document retrieval Signed-off-by: yassinnouh21 <yassinnouh21@gmail.com> * fix formatting Signed-off-by: yassinnouh21 <yassinnouh21@gmail.com> * fix: Refactor string_fields assignment to filter features by dtype and requested features Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> * fix: improve query execution logic in postgres.py Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> * fix linter Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> * fix: simplify sorting logic in query execution Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> * fix formatting Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> * fix: update string feature check to use ValueType enumeration Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> * formatting Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> * fix datetime Signed-off-by: yassinnouh21 <yassinnouh21@gmail.com> --------- Signed-off-by: yassinnouh21 <yassinnouh21@gmail.com> Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Signed-off-by: Jacob Weinhold <29459386+j-wine@users.noreply.github.com>
1 parent 814afd7 commit c2af9a7

File tree

3 files changed

+472
-13
lines changed

3 files changed

+472
-13
lines changed

sdk/python/feast/feature_store.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,15 +2097,34 @@ def _retrieve_from_online_store_v2(
20972097
entity_key_dict[key] = []
20982098
entity_key_dict[key].append(python_value)
20992099

2100-
table_entity_values, idxs, output_len = utils._get_unique_entities_from_values(
2101-
entity_key_dict,
2102-
)
2103-
21042100
features_to_request: List[str] = []
21052101
if requested_features:
21062102
features_to_request = requested_features + ["distance"]
2103+
# Add text_rank for text search queries
2104+
if query_string is not None:
2105+
features_to_request.append("text_rank")
21072106
else:
21082107
features_to_request = ["distance"]
2108+
# Add text_rank for text search queries
2109+
if query_string is not None:
2110+
features_to_request.append("text_rank")
2111+
2112+
if not datevals:
2113+
online_features_response = GetOnlineFeaturesResponse(results=[])
2114+
for feature in features_to_request:
2115+
field = online_features_response.results.add()
2116+
field.values.extend([])
2117+
field.statuses.extend([])
2118+
field.event_timestamps.extend([])
2119+
online_features_response.metadata.feature_names.val.extend(
2120+
features_to_request
2121+
)
2122+
return OnlineResponse(online_features_response)
2123+
2124+
table_entity_values, idxs, output_len = utils._get_unique_entities_from_values(
2125+
entity_key_dict,
2126+
)
2127+
21092128
feature_data = utils._convert_rows_to_protobuf(
21102129
requested_features=features_to_request,
21112130
read_rows=list(zip(datevals, list_of_feature_dicts)),

sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py

Lines changed: 275 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from psycopg.connection import Connection
2121
from psycopg_pool import AsyncConnectionPool, ConnectionPool
2222

23-
from feast import Entity
24-
from feast.feature_view import FeatureView
23+
from feast import Entity, FeatureView, ValueType
2524
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
2625
from feast.infra.online_stores.helpers import _to_naive_utc
2726
from feast.infra.online_stores.online_store import OnlineStore
@@ -119,13 +118,20 @@ def online_write_batch(
119118

120119
for feature_name, val in values.items():
121120
vector_val = None
121+
value_text = None
122+
123+
# Check if the feature type is STRING
124+
if val.WhichOneof("val") == "string_val":
125+
value_text = val.string_val
126+
122127
if config.online_store.vector_enabled:
123128
vector_val = get_list_val_str(val)
124129
insert_values.append(
125130
(
126131
entity_key_bin,
127132
feature_name,
128133
val.SerializeToString(),
134+
value_text,
129135
vector_val,
130136
timestamp,
131137
created_ts,
@@ -136,11 +142,12 @@ def online_write_batch(
136142
sql_query = sql.SQL(
137143
"""
138144
INSERT INTO {}
139-
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
140-
VALUES (%s, %s, %s, %s, %s, %s)
145+
(entity_key, feature_name, value, value_text, vector_value, event_ts, created_ts)
146+
VALUES (%s, %s, %s, %s, %s, %s, %s)
141147
ON CONFLICT (entity_key, feature_name) DO
142148
UPDATE SET
143149
value = EXCLUDED.value,
150+
value_text = EXCLUDED.value_text,
144151
vector_value = EXCLUDED.vector_value,
145152
event_ts = EXCLUDED.event_ts,
146153
created_ts = EXCLUDED.created_ts;
@@ -308,6 +315,11 @@ def update(
308315
else:
309316
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
310317
vector_value_type = "BYTEA"
318+
319+
has_string_features = any(
320+
f.dtype.to_value_type() == ValueType.STRING for f in table.features
321+
)
322+
311323
cur.execute(
312324
sql.SQL(
313325
"""
@@ -316,6 +328,7 @@ def update(
316328
entity_key BYTEA,
317329
feature_name TEXT,
318330
value BYTEA,
331+
value_text TEXT NULL, -- Added for FTS
319332
vector_value {} NULL,
320333
event_ts TIMESTAMPTZ,
321334
created_ts TIMESTAMPTZ,
@@ -331,6 +344,16 @@ def update(
331344
)
332345
)
333346

347+
if has_string_features:
348+
cur.execute(
349+
sql.SQL(
350+
"""CREATE INDEX IF NOT EXISTS {} ON {} USING GIN (to_tsvector('english', value_text));"""
351+
).format(
352+
sql.Identifier(f"{table_name}_fts_idx"),
353+
sql.Identifier(table_name),
354+
)
355+
)
356+
334357
conn.commit()
335358

336359
def teardown(
@@ -456,6 +479,254 @@ def retrieve_online_documents(
456479

457480
return result
458481

482+
def retrieve_online_documents_v2(
483+
self,
484+
config: RepoConfig,
485+
table: FeatureView,
486+
requested_features: List[str],
487+
embedding: Optional[List[float]],
488+
top_k: int,
489+
distance_metric: Optional[str] = None,
490+
query_string: Optional[str] = None,
491+
) -> List[
492+
Tuple[
493+
Optional[datetime],
494+
Optional[EntityKeyProto],
495+
Optional[Dict[str, ValueProto]],
496+
]
497+
]:
498+
"""
499+
Retrieve documents using vector similarity search or keyword search in PostgreSQL.
500+
501+
Args:
502+
config: Feast configuration object
503+
table: FeatureView object as the table to search
504+
requested_features: List of requested features to retrieve
505+
embedding: Query embedding to search for (optional)
506+
top_k: Number of items to return
507+
distance_metric: Distance metric to use (optional)
508+
query_string: The query string to search for using keyword search (optional)
509+
510+
Returns:
511+
List of tuples containing the event timestamp, entity key, and feature values
512+
"""
513+
if not config.online_store.vector_enabled:
514+
raise ValueError("Vector search is not enabled in the online store config")
515+
516+
if embedding is None and query_string is None:
517+
raise ValueError("Either embedding or query_string must be provided")
518+
519+
distance_metric = distance_metric or "L2"
520+
521+
if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT:
522+
raise ValueError(
523+
f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}"
524+
)
525+
526+
distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]
527+
528+
string_fields = [
529+
feature.name
530+
for feature in table.features
531+
if feature.dtype.to_value_type().value == 2
532+
and feature.name in requested_features
533+
]
534+
535+
table_name = _table_id(config.project, table)
536+
537+
with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
538+
query = None
539+
params: Any = None
540+
541+
if embedding is not None and query_string is not None and string_fields:
542+
# Case 1: Hybrid Search (vector + text)
543+
tsquery_str = " & ".join(query_string.split())
544+
query = sql.SQL(
545+
"""
546+
SELECT
547+
entity_key,
548+
feature_name,
549+
value,
550+
vector_value,
551+
vector_value {distance_metric_sql} %s::vector as distance,
552+
ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank,
553+
event_ts,
554+
created_ts
555+
FROM {table_name}
556+
WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s)
557+
ORDER BY distance
558+
LIMIT {top_k}
559+
"""
560+
).format(
561+
distance_metric_sql=sql.SQL(distance_metric_sql),
562+
table_name=sql.Identifier(table_name),
563+
top_k=sql.Literal(top_k),
564+
)
565+
params = (embedding, tsquery_str, string_fields, tsquery_str)
566+
567+
elif embedding is not None:
568+
# Case 2: Vector Search Only
569+
query = sql.SQL(
570+
"""
571+
SELECT
572+
entity_key,
573+
feature_name,
574+
value,
575+
vector_value,
576+
vector_value {distance_metric_sql} %s::vector as distance,
577+
NULL as text_rank, -- Keep consistent columns
578+
event_ts,
579+
created_ts
580+
FROM {table_name}
581+
ORDER BY distance
582+
LIMIT {top_k}
583+
"""
584+
).format(
585+
distance_metric_sql=sql.SQL(distance_metric_sql),
586+
table_name=sql.Identifier(table_name),
587+
top_k=sql.Literal(top_k),
588+
)
589+
params = (embedding,)
590+
591+
elif query_string is not None and string_fields:
592+
# Case 3: Text Search Only
593+
tsquery_str = " & ".join(query_string.split())
594+
query = sql.SQL(
595+
"""
596+
WITH text_matches AS (
597+
SELECT DISTINCT entity_key, ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank
598+
FROM {table_name}
599+
WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s)
600+
ORDER BY text_rank DESC
601+
LIMIT {top_k}
602+
)
603+
SELECT
604+
t1.entity_key,
605+
t1.feature_name,
606+
t1.value,
607+
t1.vector_value,
608+
NULL as distance,
609+
t2.text_rank,
610+
t1.event_ts,
611+
t1.created_ts
612+
FROM {table_name} t1
613+
INNER JOIN text_matches t2 ON t1.entity_key = t2.entity_key
614+
WHERE t1.feature_name = ANY(%s)
615+
ORDER BY t2.text_rank DESC
616+
"""
617+
).format(
618+
table_name=sql.Identifier(table_name),
619+
top_k=sql.Literal(top_k),
620+
)
621+
params = (tsquery_str, string_fields, tsquery_str, requested_features)
622+
623+
else:
624+
raise ValueError(
625+
"Either vector_enabled must be True for embedding search or string fields must be available for query_string search"
626+
)
627+
628+
cur.execute(query, params)
629+
rows = cur.fetchall()
630+
631+
# Group by entity_key to build feature records
632+
entities_dict: Dict[str, Dict[str, Any]] = defaultdict(
633+
lambda: {
634+
"features": {},
635+
"timestamp": None,
636+
"entity_key_proto": None,
637+
"vector_distance": float("inf"),
638+
"text_rank": 0.0,
639+
}
640+
)
641+
642+
for (
643+
entity_key_bytes,
644+
feature_name,
645+
feature_val_bytes,
646+
vector_val,
647+
distance,
648+
text_rank,
649+
event_ts,
650+
created_ts,
651+
) in rows:
652+
entity_key_proto = None
653+
if entity_key_bytes:
654+
from feast.infra.key_encoding_utils import deserialize_entity_key
655+
656+
entity_key_proto = deserialize_entity_key(entity_key_bytes)
657+
658+
key = entity_key_bytes.hex() if entity_key_bytes else None
659+
660+
if key is None:
661+
continue
662+
663+
entities_dict[key]["entity_key_proto"] = entity_key_proto
664+
665+
if (
666+
entities_dict[key]["timestamp"] is None
667+
or event_ts > entities_dict[key]["timestamp"]
668+
):
669+
entities_dict[key]["timestamp"] = event_ts
670+
671+
val = ValueProto()
672+
if feature_val_bytes:
673+
val.ParseFromString(feature_val_bytes)
674+
675+
entities_dict[key]["features"][feature_name] = val
676+
677+
if distance is not None:
678+
entities_dict[key]["vector_distance"] = min(
679+
entities_dict[key]["vector_distance"], float(distance)
680+
)
681+
if text_rank is not None:
682+
entities_dict[key]["text_rank"] = max(
683+
entities_dict[key]["text_rank"], float(text_rank)
684+
)
685+
686+
sorted_entities = sorted(
687+
entities_dict.values(),
688+
key=lambda x: x["vector_distance"]
689+
if embedding is not None
690+
else x["text_rank"],
691+
reverse=(embedding is None),
692+
)[:top_k]
693+
694+
result: List[
695+
Tuple[
696+
Optional[datetime],
697+
Optional[EntityKeyProto],
698+
Optional[Dict[str, ValueProto]],
699+
]
700+
] = []
701+
for entity_data in sorted_entities:
702+
features = (
703+
entity_data["features"].copy()
704+
if isinstance(entity_data["features"], dict)
705+
else None
706+
)
707+
708+
if features is not None:
709+
if "vector_distance" in entity_data and entity_data[
710+
"vector_distance"
711+
] != float("inf"):
712+
dist_val = ValueProto()
713+
dist_val.double_val = entity_data["vector_distance"]
714+
features["distance"] = dist_val
715+
716+
if embedding is None or query_string is not None:
717+
rank_val = ValueProto()
718+
rank_val.double_val = entity_data["text_rank"]
719+
features["text_rank"] = rank_val
720+
721+
result.append(
722+
(
723+
entity_data["timestamp"],
724+
entity_data["entity_key_proto"],
725+
features,
726+
)
727+
)
728+
return result
729+
459730

460731
def _table_id(project: str, table: FeatureView) -> str:
461732
return f"{project}_{table.name}"

0 commit comments

Comments
 (0)