2020from psycopg .connection import Connection
2121from psycopg_pool import AsyncConnectionPool , ConnectionPool
2222
23- from feast import Entity
24- from feast .feature_view import FeatureView
23+ from feast import Entity , FeatureView , ValueType
2524from feast .infra .key_encoding_utils import get_list_val_str , serialize_entity_key
2625from feast .infra .online_stores .helpers import _to_naive_utc
2726from feast .infra .online_stores .online_store import OnlineStore
@@ -119,13 +118,20 @@ def online_write_batch(
119118
120119 for feature_name , val in values .items ():
121120 vector_val = None
121+ value_text = None
122+
123+ # Check if the feature type is STRING
124+ if val .WhichOneof ("val" ) == "string_val" :
125+ value_text = val .string_val
126+
122127 if config .online_store .vector_enabled :
123128 vector_val = get_list_val_str (val )
124129 insert_values .append (
125130 (
126131 entity_key_bin ,
127132 feature_name ,
128133 val .SerializeToString (),
134+ value_text ,
129135 vector_val ,
130136 timestamp ,
131137 created_ts ,
@@ -136,11 +142,12 @@ def online_write_batch(
136142 sql_query = sql .SQL (
137143 """
138144 INSERT INTO {}
139- (entity_key, feature_name, value, vector_value, event_ts, created_ts)
140- VALUES (%s, %s, %s, %s, %s, %s)
145+ (entity_key, feature_name, value, value_text, vector_value, event_ts, created_ts)
146+ VALUES (%s, %s, %s, %s, %s, %s, %s )
141147 ON CONFLICT (entity_key, feature_name) DO
142148 UPDATE SET
143149 value = EXCLUDED.value,
150+ value_text = EXCLUDED.value_text,
144151 vector_value = EXCLUDED.vector_value,
145152 event_ts = EXCLUDED.event_ts,
146153 created_ts = EXCLUDED.created_ts;
@@ -308,6 +315,11 @@ def update(
308315 else :
309316 # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
310317 vector_value_type = "BYTEA"
318+
319+ has_string_features = any (
320+ f .dtype .to_value_type () == ValueType .STRING for f in table .features
321+ )
322+
311323 cur .execute (
312324 sql .SQL (
313325 """
@@ -316,6 +328,7 @@ def update(
316328 entity_key BYTEA,
317329 feature_name TEXT,
318330 value BYTEA,
331+ value_text TEXT NULL, -- Added for FTS
319332 vector_value {} NULL,
320333 event_ts TIMESTAMPTZ,
321334 created_ts TIMESTAMPTZ,
@@ -331,6 +344,16 @@ def update(
331344 )
332345 )
333346
347+ if has_string_features :
348+ cur .execute (
349+ sql .SQL (
350+ """CREATE INDEX IF NOT EXISTS {} ON {} USING GIN (to_tsvector('english', value_text));"""
351+ ).format (
352+ sql .Identifier (f"{ table_name } _fts_idx" ),
353+ sql .Identifier (table_name ),
354+ )
355+ )
356+
334357 conn .commit ()
335358
336359 def teardown (
@@ -456,6 +479,254 @@ def retrieve_online_documents(
456479
457480 return result
458481
482+ def retrieve_online_documents_v2 (
483+ self ,
484+ config : RepoConfig ,
485+ table : FeatureView ,
486+ requested_features : List [str ],
487+ embedding : Optional [List [float ]],
488+ top_k : int ,
489+ distance_metric : Optional [str ] = None ,
490+ query_string : Optional [str ] = None ,
491+ ) -> List [
492+ Tuple [
493+ Optional [datetime ],
494+ Optional [EntityKeyProto ],
495+ Optional [Dict [str , ValueProto ]],
496+ ]
497+ ]:
498+ """
499+ Retrieve documents using vector similarity search or keyword search in PostgreSQL.
500+
501+ Args:
502+ config: Feast configuration object
503+ table: FeatureView object as the table to search
504+ requested_features: List of requested features to retrieve
505+ embedding: Query embedding to search for (optional)
506+ top_k: Number of items to return
507+ distance_metric: Distance metric to use (optional)
508+ query_string: The query string to search for using keyword search (optional)
509+
510+ Returns:
511+ List of tuples containing the event timestamp, entity key, and feature values
512+ """
513+ if not config .online_store .vector_enabled :
514+ raise ValueError ("Vector search is not enabled in the online store config" )
515+
516+ if embedding is None and query_string is None :
517+ raise ValueError ("Either embedding or query_string must be provided" )
518+
519+ distance_metric = distance_metric or "L2"
520+
521+ if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT :
522+ raise ValueError (
523+ f"Distance metric { distance_metric } is not supported. Supported distance metrics are { SUPPORTED_DISTANCE_METRICS_DICT .keys ()} "
524+ )
525+
526+ distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT [distance_metric ]
527+
528+ string_fields = [
529+ feature .name
530+ for feature in table .features
531+ if feature .dtype .to_value_type ().value == 2
532+ and feature .name in requested_features
533+ ]
534+
535+ table_name = _table_id (config .project , table )
536+
537+ with self ._get_conn (config , autocommit = True ) as conn , conn .cursor () as cur :
538+ query = None
539+ params : Any = None
540+
541+ if embedding is not None and query_string is not None and string_fields :
542+ # Case 1: Hybrid Search (vector + text)
543+ tsquery_str = " & " .join (query_string .split ())
544+ query = sql .SQL (
545+ """
546+ SELECT
547+ entity_key,
548+ feature_name,
549+ value,
550+ vector_value,
551+ vector_value {distance_metric_sql} %s::vector as distance,
552+ ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank,
553+ event_ts,
554+ created_ts
555+ FROM {table_name}
556+ WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s)
557+ ORDER BY distance
558+ LIMIT {top_k}
559+ """
560+ ).format (
561+ distance_metric_sql = sql .SQL (distance_metric_sql ),
562+ table_name = sql .Identifier (table_name ),
563+ top_k = sql .Literal (top_k ),
564+ )
565+ params = (embedding , tsquery_str , string_fields , tsquery_str )
566+
567+ elif embedding is not None :
568+ # Case 2: Vector Search Only
569+ query = sql .SQL (
570+ """
571+ SELECT
572+ entity_key,
573+ feature_name,
574+ value,
575+ vector_value,
576+ vector_value {distance_metric_sql} %s::vector as distance,
577+ NULL as text_rank, -- Keep consistent columns
578+ event_ts,
579+ created_ts
580+ FROM {table_name}
581+ ORDER BY distance
582+ LIMIT {top_k}
583+ """
584+ ).format (
585+ distance_metric_sql = sql .SQL (distance_metric_sql ),
586+ table_name = sql .Identifier (table_name ),
587+ top_k = sql .Literal (top_k ),
588+ )
589+ params = (embedding ,)
590+
591+ elif query_string is not None and string_fields :
592+ # Case 3: Text Search Only
593+ tsquery_str = " & " .join (query_string .split ())
594+ query = sql .SQL (
595+ """
596+ WITH text_matches AS (
597+ SELECT DISTINCT entity_key, ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank
598+ FROM {table_name}
599+ WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s)
600+ ORDER BY text_rank DESC
601+ LIMIT {top_k}
602+ )
603+ SELECT
604+ t1.entity_key,
605+ t1.feature_name,
606+ t1.value,
607+ t1.vector_value,
608+ NULL as distance,
609+ t2.text_rank,
610+ t1.event_ts,
611+ t1.created_ts
612+ FROM {table_name} t1
613+ INNER JOIN text_matches t2 ON t1.entity_key = t2.entity_key
614+ WHERE t1.feature_name = ANY(%s)
615+ ORDER BY t2.text_rank DESC
616+ """
617+ ).format (
618+ table_name = sql .Identifier (table_name ),
619+ top_k = sql .Literal (top_k ),
620+ )
621+ params = (tsquery_str , string_fields , tsquery_str , requested_features )
622+
623+ else :
624+ raise ValueError (
625+ "Either vector_enabled must be True for embedding search or string fields must be available for query_string search"
626+ )
627+
628+ cur .execute (query , params )
629+ rows = cur .fetchall ()
630+
631+ # Group by entity_key to build feature records
632+ entities_dict : Dict [str , Dict [str , Any ]] = defaultdict (
633+ lambda : {
634+ "features" : {},
635+ "timestamp" : None ,
636+ "entity_key_proto" : None ,
637+ "vector_distance" : float ("inf" ),
638+ "text_rank" : 0.0 ,
639+ }
640+ )
641+
642+ for (
643+ entity_key_bytes ,
644+ feature_name ,
645+ feature_val_bytes ,
646+ vector_val ,
647+ distance ,
648+ text_rank ,
649+ event_ts ,
650+ created_ts ,
651+ ) in rows :
652+ entity_key_proto = None
653+ if entity_key_bytes :
654+ from feast .infra .key_encoding_utils import deserialize_entity_key
655+
656+ entity_key_proto = deserialize_entity_key (entity_key_bytes )
657+
658+ key = entity_key_bytes .hex () if entity_key_bytes else None
659+
660+ if key is None :
661+ continue
662+
663+ entities_dict [key ]["entity_key_proto" ] = entity_key_proto
664+
665+ if (
666+ entities_dict [key ]["timestamp" ] is None
667+ or event_ts > entities_dict [key ]["timestamp" ]
668+ ):
669+ entities_dict [key ]["timestamp" ] = event_ts
670+
671+ val = ValueProto ()
672+ if feature_val_bytes :
673+ val .ParseFromString (feature_val_bytes )
674+
675+ entities_dict [key ]["features" ][feature_name ] = val
676+
677+ if distance is not None :
678+ entities_dict [key ]["vector_distance" ] = min (
679+ entities_dict [key ]["vector_distance" ], float (distance )
680+ )
681+ if text_rank is not None :
682+ entities_dict [key ]["text_rank" ] = max (
683+ entities_dict [key ]["text_rank" ], float (text_rank )
684+ )
685+
686+ sorted_entities = sorted (
687+ entities_dict .values (),
688+ key = lambda x : x ["vector_distance" ]
689+ if embedding is not None
690+ else x ["text_rank" ],
691+ reverse = (embedding is None ),
692+ )[:top_k ]
693+
694+ result : List [
695+ Tuple [
696+ Optional [datetime ],
697+ Optional [EntityKeyProto ],
698+ Optional [Dict [str , ValueProto ]],
699+ ]
700+ ] = []
701+ for entity_data in sorted_entities :
702+ features = (
703+ entity_data ["features" ].copy ()
704+ if isinstance (entity_data ["features" ], dict )
705+ else None
706+ )
707+
708+ if features is not None :
709+ if "vector_distance" in entity_data and entity_data [
710+ "vector_distance"
711+ ] != float ("inf" ):
712+ dist_val = ValueProto ()
713+ dist_val .double_val = entity_data ["vector_distance" ]
714+ features ["distance" ] = dist_val
715+
716+ if embedding is None or query_string is not None :
717+ rank_val = ValueProto ()
718+ rank_val .double_val = entity_data ["text_rank" ]
719+ features ["text_rank" ] = rank_val
720+
721+ result .append (
722+ (
723+ entity_data ["timestamp" ],
724+ entity_data ["entity_key_proto" ],
725+ features ,
726+ )
727+ )
728+ return result
729+
459730
460731def _table_id (project : str , table : FeatureView ) -> str :
461732 return f"{ project } _{ table .name } "
0 commit comments