diff --git a/sdk/python/feast/arrow_error_handler.py b/sdk/python/feast/arrow_error_handler.py index e873592bd5d..e4862bb0982 100644 --- a/sdk/python/feast/arrow_error_handler.py +++ b/sdk/python/feast/arrow_error_handler.py @@ -30,6 +30,9 @@ def wrapper(*args, **kwargs): except Exception as e: if isinstance(e, FeastError): raise fl.FlightError(e.to_error_detail()) + # Re-raise non-Feast exceptions so Arrow Flight returns a proper error + # instead of allowing the server method to return None. + raise e return wrapper diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index f1ba4baa939..47e76a014f0 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -3,12 +3,13 @@ import uuid import warnings from dataclasses import asdict, dataclass -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import ( TYPE_CHECKING, Any, Callable, Dict, + KeysView, List, Optional, Tuple, @@ -151,10 +152,11 @@ def get_historical_features( config: RepoConfig, feature_views: List[FeatureView], feature_refs: List[str], - entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame], + entity_df: Optional[Union[pandas.DataFrame, str, pyspark.sql.DataFrame]], registry: BaseRegistry, project: str, full_feature_names: bool = False, + **kwargs, ) -> RetrievalJob: assert isinstance(config.offline_store, SparkOfflineStoreConfig) date_partition_column_formats = [] @@ -175,33 +177,75 @@ def get_historical_features( ) tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name() - entity_schema = _get_entity_schema( - spark_session=spark_session, - entity_df=entity_df, - ) - event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( - entity_schema=entity_schema, - ) - entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, - event_timestamp_col, - spark_session, - ) - _upload_entity_df( - spark_session=spark_session, - table_name=tmp_entity_df_table_name, - entity_df=entity_df, - event_timestamp_col=event_timestamp_col, - ) + # Non-entity mode: synthesize a left table and timestamp range from start/end dates to avoid requiring entity_df. + # This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time. + non_entity_mode = entity_df is None + if non_entity_mode: + # Why: derive bounded time window without requiring entities; uses max TTL fallback to constrain scans. + start_date, end_date = _compute_non_entity_dates(feature_views, kwargs) + entity_df_event_timestamp_range = (start_date, end_date) + + # Build query contexts so we can reuse entity names and per-view table info consistently. + fv_query_contexts = offline_utils.get_feature_view_query_context( + feature_refs, + feature_views, + registry, + project, + entity_df_event_timestamp_range, + ) - expected_join_keys = offline_utils.get_expected_join_keys( - project=project, feature_views=feature_views, registry=registry - ) - offline_utils.assert_expected_columns_in_entity_df( - entity_schema=entity_schema, - join_keys=expected_join_keys, - entity_df_event_timestamp_col=event_timestamp_col, - ) + # Collect the union of entity columns required across all feature views. 
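+            # For example (illustrative entity names, not from the registry): if one
+            # view keys on driver_id and another on customer_id, the union is
+            # [driver_id, customer_id]; a view that lacks a key contributes NULL for
+            # that column in the union view built below, keeping the UNION schemas aligned.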
+ all_entities = _gather_all_entities(fv_query_contexts) + + # Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned. + _create_temp_entity_union_view( + spark_session=spark_session, + tmp_view_name=tmp_entity_df_table_name, + feature_views=feature_views, + fv_query_contexts=fv_query_contexts, + start_date=start_date, + end_date=end_date, + date_partition_column_formats=date_partition_column_formats, + ) + + # Add a stable as-of timestamp column for PIT joins. + left_table_query_string, event_timestamp_col = _make_left_table_query( + end_date=end_date, tmp_view_name=tmp_entity_df_table_name + ) + entity_schema_keys = _entity_schema_keys_from( + all_entities=all_entities, event_timestamp_col=event_timestamp_col + ) + else: + entity_schema = _get_entity_schema( + spark_session=spark_session, + entity_df=entity_df, + ) + event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema=entity_schema, + ) + entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( + entity_df, + event_timestamp_col, + spark_session, + ) + _upload_entity_df( + spark_session=spark_session, + table_name=tmp_entity_df_table_name, + entity_df=entity_df, + event_timestamp_col=event_timestamp_col, + ) + left_table_query_string = tmp_entity_df_table_name + entity_schema_keys = cast(KeysView[str], entity_schema.keys()) + + if not non_entity_mode: + expected_join_keys = offline_utils.get_expected_join_keys( + project=project, feature_views=feature_views, registry=registry + ) + offline_utils.assert_expected_columns_in_entity_df( + entity_schema=entity_schema, + join_keys=expected_join_keys, + entity_df_event_timestamp_col=event_timestamp_col, + ) query_context = offline_utils.get_feature_view_query_context( feature_refs, @@ -232,9 +276,9 @@ def get_historical_features( feature_view_query_contexts=cast( List[offline_utils.FeatureViewQueryContext], spark_query_context ), - left_table_query_string=tmp_entity_df_table_name, + left_table_query_string=left_table_query_string, entity_df_event_timestamp_col=event_timestamp_col, - entity_df_columns=entity_schema.keys(), + entity_df_columns=entity_schema_keys, query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, full_feature_names=full_feature_names, ) @@ -248,7 +292,7 @@ def get_historical_features( ), metadata=RetrievalMetadata( features=feature_refs, - keys=list(set(entity_schema.keys()) - {event_timestamp_col}), + keys=list(set(entity_schema_keys) - {event_timestamp_col}), min_event_timestamp=entity_df_event_timestamp_range[0], max_event_timestamp=entity_df_event_timestamp_range[1], ), @@ -540,6 +584,114 @@ def get_spark_session_or_start_new_with_repoconfig( return spark_session +def _compute_non_entity_dates( + feature_views: List[FeatureView], kwargs: Dict[str, Any] +) -> Tuple[datetime, datetime]: + # Why: bounds the scan window when no entity_df is provided using explicit dates or max TTL fallback. 
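+    # Example (illustrative values): end_date=2021-01-02 with view TTLs of 7 and 30
+    # days yields a window starting 30 days back (2020-12-03); if no view defines a
+    # TTL, a 30-day default window is assumed rather than scanning unbounded history.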
+ start_date_opt = cast(Optional[datetime], kwargs.get("start_date")) + end_date_opt = cast(Optional[datetime], kwargs.get("end_date")) + end_date: datetime = end_date_opt or datetime.now(timezone.utc) + + if start_date_opt is None: + max_ttl_seconds = 0 + for fv in feature_views: + if fv.ttl and isinstance(fv.ttl, timedelta): + max_ttl_seconds = max(max_ttl_seconds, int(fv.ttl.total_seconds())) + start_date: datetime = ( + end_date - timedelta(seconds=max_ttl_seconds) + if max_ttl_seconds > 0 + else end_date - timedelta(days=30) + ) + else: + start_date = start_date_opt + return (start_date, end_date) + + +def _gather_all_entities( + fv_query_contexts: List[offline_utils.FeatureViewQueryContext], +) -> List[str]: + # Why: ensure a unified entity set across feature views to align UNION schemas. + all_entities: List[str] = [] + for ctx in fv_query_contexts: + for e in ctx.entities: + if e not in all_entities: + all_entities.append(e) + return all_entities + + +def _create_temp_entity_union_view( + spark_session: SparkSession, + tmp_view_name: str, + feature_views: List[FeatureView], + fv_query_contexts: List[offline_utils.FeatureViewQueryContext], + start_date: datetime, + end_date: datetime, + date_partition_column_formats: List[Optional[str]], +) -> None: + # Why: derive distinct entity keys observed in the time window without requiring an entity_df upfront. + start_date_str = _format_datetime(start_date) + end_date_str = _format_datetime(end_date) + + # Compute the unified entity set to align schemas in the UNION. + all_entities = _gather_all_entities(fv_query_contexts) + + per_view_selects: List[str] = [] + for fv, ctx, date_format in zip( + feature_views, fv_query_contexts, date_partition_column_formats + ): + assert isinstance(fv.batch_source, SparkSource) + from_expression = fv.batch_source.get_table_query_string() + timestamp_field = fv.batch_source.timestamp_field or "event_timestamp" + date_partition_column = fv.batch_source.date_partition_column + partition_clause = "" + if date_partition_column and date_format: + partition_clause = ( + f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'" + f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'" + ) + + # Fill missing entity columns with NULL and cast to STRING to keep UNION schemas aligned. + select_entities: List[str] = [] + ctx_entities_set = set(ctx.entities) + for col in all_entities: + if col in ctx_entities_set: + select_entities.append(f"CAST({col} AS STRING) AS {col}") + else: + select_entities.append(f"CAST(NULL AS STRING) AS {col}") + + per_view_selects.append( + f""" + SELECT DISTINCT {", ".join(select_entities)} + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause} + """ + ) + + union_query = "\nUNION DISTINCT\n".join([s.strip() for s in per_view_selects]) + spark_session.sql( + f"CREATE OR REPLACE TEMPORARY VIEW {tmp_view_name} AS {union_query}" + ) + + +def _make_left_table_query(end_date: datetime, tmp_view_name: str) -> Tuple[str, str]: + # Why: use a stable as-of timestamp for PIT joins when no entity timestamps are provided. 
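+    # The left table takes the form
+    #   (SELECT *, TIMESTAMP('<end_date>') AS entity_ts FROM <tmp_view_name>)
+    # so every synthesized entity row is joined point-in-time as of end_date.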
+ event_timestamp_col = "entity_ts" + left_table_query_string = ( + f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS {event_timestamp_col} " + f"FROM {tmp_view_name})" + ) + return left_table_query_string, event_timestamp_col + + +def _entity_schema_keys_from( + all_entities: List[str], event_timestamp_col: str +) -> KeysView[str]: + # Why: pass a KeysView[str] to PIT query builder to match entity_df branch typing. + return cast( + KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys() + ) + + def _get_entity_df_event_timestamp_range( entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str, diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 6f2af7054b4..cd41921e56a 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -229,7 +229,8 @@ def get_table_query_string(self) -> str: # If both the table query string and the actual query are null, we can load from file. spark_session = SparkSession.getActiveSession() if spark_session is None: - raise AssertionError("Could not find an active spark session.") + # Remote mode may not have an active session bound to the thread; create one on demand. + spark_session = SparkSession.builder.getOrCreate() try: df = self._load_dataframe_from_path(spark_session) except Exception: diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py index 938514a2ca0..22c75ebf387 100644 --- a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py @@ -339,3 +339,168 @@ def _mock_entity(): value_type=ValueType.INT64, ) ] + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" +) +def test_get_historical_features_non_entity_with_date_range(mock_get_spark_session): + mock_spark_session = MagicMock() + # Return a DataFrame for any sql call; last call is used by RetrievalJob + final_df = MagicMock() + expected_pdf = pd.DataFrame([{"feature1": 1.0, "feature2": 2.0}]) + final_df.toPandas.return_value = expected_pdf + mock_spark_session.sql.return_value = final_df + mock_get_spark_session.return_value = mock_spark_session + + test_repo_config = RepoConfig( + project="test_project", + registry="test_registry", + provider="local", + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + test_data_source1 = SparkSource( + name="test_nested_batch_source1", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name1", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + date_partition_column_format="%Y%m%d", + ) + + test_data_source2 = SparkSource( + name="test_nested_batch_source2", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name2", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + ) + + test_feature_view1 = FeatureView( + 
name="test_feature_view1", + entities=_mock_entity(), + schema=[ + Field(name="feature1", dtype=Float32), + ], + source=test_data_source1, + ) + + test_feature_view2 = FeatureView( + name="test_feature_view2", + entities=_mock_entity(), + schema=[ + Field(name="feature2", dtype=Float32), + ], + source=test_data_source2, + ) + + mock_registry = MagicMock() + start_date = datetime(2021, 1, 1) + end_date = datetime(2021, 1, 2) + retrieval_job = SparkOfflineStore.get_historical_features( + config=test_repo_config, + feature_views=[test_feature_view2, test_feature_view1], + feature_refs=["test_feature_view2:feature2", "test_feature_view1:feature1"], + entity_df=None, + registry=mock_registry, + project="test_project", + start_date=start_date, + end_date=end_date, + ) + + # Verify query bounded by end_date correctly in both date formats from the two sources + query = retrieval_job.query + assert "effective_date <= '2021-01-02'" in query + assert "effective_date <= '20210102'" in query + + # Verify data: the mocked Spark DataFrame flows through to Pandas + pdf = retrieval_job._to_df_internal() + assert pdf.equals(expected_pdf) + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" +) +def test_get_historical_features_non_entity_with_only_end_date(mock_get_spark_session): + mock_spark_session = MagicMock() + final_df = MagicMock() + expected_pdf = pd.DataFrame([{"feature1": 10.0, "feature2": 20.0}]) + final_df.toPandas.return_value = expected_pdf + mock_spark_session.sql.return_value = final_df + mock_get_spark_session.return_value = mock_spark_session + + test_repo_config = RepoConfig( + project="test_project", + registry="test_registry", + provider="local", + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + test_data_source1 = SparkSource( + name="test_nested_batch_source1", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name1", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + date_partition_column_format="%Y%m%d", + ) + + test_data_source2 = SparkSource( + name="test_nested_batch_source2", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name2", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + ) + + test_feature_view1 = FeatureView( + name="test_feature_view1", + entities=_mock_entity(), + schema=[ + Field(name="feature1", dtype=Float32), + ], + source=test_data_source1, + ) + + test_feature_view2 = FeatureView( + name="test_feature_view2", + entities=_mock_entity(), + schema=[ + Field(name="feature2", dtype=Float32), + ], + source=test_data_source2, + ) + + mock_registry = MagicMock() + end_date = datetime(2021, 1, 2) + retrieval_job = SparkOfflineStore.get_historical_features( + config=test_repo_config, + feature_views=[test_feature_view2, test_feature_view1], + feature_refs=["test_feature_view2:feature2", "test_feature_view1:feature1"], + entity_df=None, + registry=mock_registry, + project="test_project", + end_date=end_date, + ) + + # Verify query bounded by end_date correctly for both sources + query = retrieval_job.query + assert "effective_date <= '2021-01-02'" in query + assert "effective_date <= '20210102'" in query + + # Verify data: mocked DataFrame flows to 
Pandas + pdf = retrieval_job._to_df_internal() + assert pdf.equals(expected_pdf)