diff --git a/docs/getting-started/architecture/feature-transformation.md b/docs/getting-started/architecture/feature-transformation.md
index 1a15d4c3a51..562a733fef0 100644
--- a/docs/getting-started/architecture/feature-transformation.md
+++ b/docs/getting-started/architecture/feature-transformation.md
@@ -8,7 +8,7 @@ Feature transformations can be executed by three types of "transformation engines":
 
 1. The Feast Feature Server
 2. An Offline Store (e.g., Snowflake, BigQuery, DuckDB, Spark, etc.)
-3. A Stream processor (e.g., Flink or Spark Streaming)
+3. [A Compute Engine](../../reference/compute-engine/README.md)
 
 The three transformation engines are coupled with the [communication pattern used for writes](write-patterns.md).
diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md
new file mode 100644
index 00000000000..50aaa5befab
--- /dev/null
+++ b/docs/reference/compute-engine/README.md
@@ -0,0 +1,119 @@
+# 🧠 ComputeEngine (WIP)
+
+The `ComputeEngine` is Feast's pluggable abstraction for executing feature pipelines (transformations, aggregations, joins, and materialization/historical retrieval) on a backend of your choice (e.g., Spark, PyArrow, Pandas, Ray).
+
+It powers both:
+
+- `materialize()` – batch and stream generation of features into the offline/online stores
+- `get_historical_features()` – point-in-time correct training dataset retrieval
+
+This system builds and executes DAGs (Directed Acyclic Graphs) of typed operations, enabling modular and scalable workflows.
+
+---
+
+## 🧠 Core Concepts
+
+| Component          | Description                                                            |
+|--------------------|------------------------------------------------------------------------|
+| `ComputeEngine`    | Interface for executing materialization and retrieval tasks            |
+| `FeatureBuilder`   | Constructs a DAG from a FeatureView definition for a specific backend  |
+| `DAGNode`          | Represents a logical operation (read, aggregate, join, etc.)           |
+| `ExecutionPlan`    | Executes nodes in dependency order and stores intermediate outputs     |
+| `ExecutionContext` | Holds config, registry, stores, entity data, and node outputs          |
+
+---
+
+## ✨ Available Engines
+
+### 🔥 SparkComputeEngine
+
+- Distributed DAG execution via Apache Spark
+- Supports point-in-time joins and large-scale materialization
+- Integrates with `SparkOfflineStore` and `SparkMaterializationJob`
+
+### 🧪 LocalComputeEngine (WIP)
+
+- Runs on Arrow + Pandas (or optionally DuckDB)
+- Designed for local development, testing, or lightweight feature generation
+
+---
+
+## 🛠️ Feature Builder Flow
+
+```
+SourceReadNode
+   |
+   v
+JoinNode (only for get_historical_features with an entity df)
+   |
+   v
+FilterNode (always included; applies TTL or user-defined filters)
+   |
+   v
+AggregationNode (if aggregations are defined in the FeatureView)
+   |
+   v
+DeduplicationNode (if no aggregation is defined, for get_historical_features)
+   |
+   v
+TransformationNode (if feature_transformation is defined)
+   |
+   v
+ValidationNode (if enable_validation = True)
+   |
+   v
+Output
+  ├──> RetrievalOutput (for get_historical_features)
+  └──> OnlineStoreWrite / OfflineStoreWrite (for materialize)
+```
+
+Each step is implemented as a `DAGNode`. An `ExecutionPlan` executes these nodes in topological order, caching `DAGValue` outputs.
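+
+For illustration, here is a minimal sketch of how these pieces compose, using pandas data and mocked stores. This is not a supported engine, just a demonstration of the `DAGNode` / `ExecutionPlan` / `ExecutionContext` contracts; the node names are made up.
+
+```python
+import pandas as pd
+from unittest.mock import MagicMock
+
+from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext
+from feast.infra.compute_engines.dag.model import DAGFormat
+from feast.infra.compute_engines.dag.node import DAGNode
+from feast.infra.compute_engines.dag.plan import ExecutionPlan
+from feast.infra.compute_engines.dag.value import DAGValue
+
+
+class ReadNode(DAGNode):
+    def execute(self, context: ExecutionContext) -> DAGValue:
+        # Source step: produce raw feature rows (here, an in-memory DataFrame).
+        df = pd.DataFrame({"driver_id": [1001, 1002], "conv_rate": [0.8, 0.7]})
+        return DAGValue(data=df, format=DAGFormat.PANDAS)
+
+
+class FilterNode(DAGNode):
+    def execute(self, context: ExecutionContext) -> DAGValue:
+        # Downstream step: consume the upstream output cached in the context.
+        upstream = self.get_single_input_value(context)
+        upstream.assert_format(DAGFormat.PANDAS)
+        filtered = upstream.data[upstream.data["conv_rate"] > 0.75]
+        return DAGValue(data=filtered, format=DAGFormat.PANDAS)
+
+
+read = ReadNode("read")
+filter_node = FilterNode("filter")
+filter_node.add_input(read)
+
+context = ExecutionContext(
+    project="demo",
+    repo_config=MagicMock(),
+    offline_store=MagicMock(),
+    online_store=MagicMock(),
+    column_info=ColumnInfo(
+        join_keys=["driver_id"],
+        feature_cols=["conv_rate"],
+        ts_col="",
+        created_ts_col=None,
+    ),
+    entity_defs=[],
+)
+
+result = ExecutionPlan([read, filter_node]).execute(context)
+print(result.data)  # only the rows with conv_rate > 0.75
+```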
+
+---
+
+## 🧩 Implementing a Custom Compute Engine
+
+To create your own compute engine:
+
+1. **Implement the interface**
+
+```python
+from feast.infra.compute_engines.base import ComputeEngine
+from feast.infra.compute_engines.tasks import HistoricalRetrievalTask
+from feast.infra.materialization.batch_materialization_engine import (
+    MaterializationJob,
+    MaterializationTask,
+)
+from feast.infra.offline_stores.offline_store import RetrievalJob
+
+
+class MyComputeEngine(ComputeEngine):
+    def materialize(self, task: MaterializationTask) -> MaterializationJob:
+        ...
+
+    def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob:
+        ...
+```
+
+2. **Create a FeatureBuilder**
+
+```python
+from feast.infra.compute_engines.feature_builder import FeatureBuilder
+
+
+class CustomFeatureBuilder(FeatureBuilder):
+    def build_source_node(self): ...
+    def build_aggregation_node(self, input_node): ...
+    def build_join_node(self, input_node): ...
+    def build_filter_node(self, input_node): ...
+    def build_dedup_node(self, input_node): ...
+    def build_transformation_node(self, input_node): ...
+    def build_validation_node(self, input_node): ...
+    def build_output_nodes(self, input_node): ...
+```
+
+3. **Define DAGNode subclasses**
+   * ReadNode, AggregationNode, JoinNode, WriteNode, etc.
+   * Each `DAGNode.execute(context)` returns a `DAGValue`
+
+4. **Return an ExecutionPlan**
+   * ExecutionPlan stores DAG nodes in topological order
+   * Automatically handles intermediate value caching
+
+## 🚧 Roadmap
+- [x] Modular, backend-agnostic DAG execution framework
+- [x] Spark engine with native support for materialization + PIT joins
+- [ ] PyArrow + Pandas engine for local compute
+- [ ] Native multi-feature-view DAG optimization
+- [ ] DAG validation, metrics, and debug output
+- [ ] Scalable distributed backend via Ray or Polars
diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py
index 57d5aa1b07e..2441e4bc859 100644
--- a/sdk/python/feast/batch_feature_view.py
+++ b/sdk/python/feast/batch_feature_view.py
@@ -6,6 +6,7 @@
 import dill
 
 from feast import flags_helper
+from feast.aggregation import Aggregation
 from feast.data_source import DataSource
 from feast.entity import Entity
 from feast.feature_view import FeatureView
@@ -40,7 +41,8 @@ class BatchFeatureView(FeatureView):
         schema: The schema of the feature view, including feature, timestamp, and entity columns.
             If not specified, can be inferred from the underlying data source.
         source: The batch source of data where this group of features is stored.
-        online: A boolean indicating whether online retrieval is enabled for this feature view.
+        online: A boolean indicating whether online retrieval and writes to the online store are enabled for this feature view.
+        offline: A boolean indicating whether offline retrieval and writes to the offline store are enabled for this feature view.
         description: A human-readable description.
         tags: A dictionary of key-value pairs to store arbitrary metadata.
         owner: The owner of the batch feature view, typically the email of the primary maintainer.
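To make the new flags concrete, a hypothetical `BatchFeatureView` using `online`, `offline`, and `aggregations` as documented above might look like the following. The source path, schema, and transformation are illustrative only and not part of this change.

```python
from datetime import timedelta

from feast import BatchFeatureView, Entity, Field, FileSource
from feast.aggregation import Aggregation
from feast.types import Float32, Int32


def double_conv_rate(df):
    # Toy feature transformation applied by the compute engine.
    df["conv_rate"] = df["conv_rate"] * 2
    return df


driver = Entity(name="driver_id", description="driver id")

driver_stats = BatchFeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    mode="python",
    aggregations=[Aggregation(column="conv_rate", function="sum")],
    udf=double_conv_rate,
    udf_string="double_conv_rate",
    ttl=timedelta(days=3),
    schema=[
        Field(name="conv_rate", dtype=Float32),
        Field(name="driver_id", dtype=Int32),
    ],
    online=False,   # do not write to the online store
    offline=True,   # write materialized features to the offline store
    source=FileSource(
        path="data/driver_stats.parquet",  # illustrative path
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
    ),
)
```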
@@ -55,6 +57,7 @@ class BatchFeatureView(FeatureView): entity_columns: List[Field] features: List[Field] online: bool + offline: bool description: str tags: Dict[str, str] owner: str @@ -63,6 +66,8 @@ class BatchFeatureView(FeatureView): udf: Optional[Callable[[Any], Any]] udf_string: Optional[str] feature_transformation: Transformation + batch_engine: Optional[Field] + aggregations: Optional[List[Aggregation]] def __init__( self, @@ -73,13 +78,16 @@ def __init__( entities: Optional[List[Entity]] = None, ttl: Optional[timedelta] = None, tags: Optional[Dict[str, str]] = None, - online: bool = True, + online: bool = False, + offline: bool = True, description: str = "", owner: str = "", schema: Optional[List[Field]] = None, udf: Optional[Callable[[Any], Any]], udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, + batch_engine: Optional[Field] = None, + aggregations: Optional[List[Aggregation]] = None, ): if not flags_helper.is_test(): warnings.warn( @@ -103,6 +111,8 @@ def __init__( self.feature_transformation = ( feature_transformation or self.get_feature_transformation() ) + self.batch_engine = batch_engine + self.aggregations = aggregations or [] super().__init__( name=name, @@ -110,6 +120,7 @@ def __init__( ttl=ttl, tags=tags, online=online, + offline=offline, description=description, owner=owner, schema=schema, @@ -144,6 +155,7 @@ def batch_feature_view( source: Optional[DataSource] = None, tags: Optional[Dict[str, str]] = None, online: bool = True, + offline: bool = True, description: str = "", owner: str = "", schema: Optional[List[Field]] = None, @@ -151,11 +163,13 @@ def batch_feature_view( """ Args: name: + mode: entities: ttl: source: tags: online: + offline: description: owner: schema: @@ -181,6 +195,7 @@ def decorator(user_function): source=source, tags=tags, online=online, + offline=offline, description=description, owner=owner, schema=schema, diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 49b74893451..5259d5d2b90 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -93,6 +93,7 @@ class FeatureView(BaseFeatureView): entity_columns: List[Field] features: List[Field] online: bool + offline: bool description: str tags: Dict[str, str] owner: str @@ -107,6 +108,7 @@ def __init__( entities: Optional[List[Entity]] = None, ttl: Optional[timedelta] = timedelta(days=0), online: bool = True, + offline: bool = False, description: str = "", tags: Optional[Dict[str, str]] = None, owner: str = "", @@ -127,6 +129,8 @@ def __init__( can result in extremely computationally intensive queries. online (optional): A boolean indicating whether online retrieval is enabled for this feature view. + offline (optional): A boolean indicating whether write to offline store is enabled for + this feature view. description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. 
owner (optional): The owner of the feature view, typically the email of the @@ -218,6 +222,7 @@ def __init__( source=source, ) self.online = online + self.offline = offline self.materialization_intervals = [] def __hash__(self): diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py new file mode 100644 index 00000000000..d5372d246aa --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -0,0 +1,83 @@ +from abc import ABC +from typing import Union + +import pyarrow as pa + +from feast import RepoConfig +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import ( + MaterializationJob, + MaterializationTask, +) +from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.online_stores.online_store import OnlineStore +from feast.infra.registry.registry import Registry +from feast.utils import _get_column_names + + +class ComputeEngine(ABC): + """ + The interface that Feast uses to control the compute system that handles materialization and get_historical_features. + Each engine must implement: + - materialize(): to generate and persist features + - get_historical_features(): to perform point-in-time correct joins + Engines should use FeatureBuilder and DAGNode abstractions to build modular, pluggable workflows. + """ + + def __init__( + self, + *, + registry: Registry, + repo_config: RepoConfig, + offline_store: OfflineStore, + online_store: OnlineStore, + **kwargs, + ): + self.registry = registry + self.repo_config = repo_config + self.offline_store = offline_store + self.online_store = online_store + + def materialize(self, task: MaterializationTask) -> MaterializationJob: + raise NotImplementedError + + def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: + raise NotImplementedError + + def get_execution_context( + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + ) -> ExecutionContext: + entity_defs = [ + self.registry.get_entity(name, task.project) + for name in task.feature_view.entities + ] + entity_df = None + if hasattr(task, "entity_df") and task.entity_df is not None: + entity_df = task.entity_df + + column_info = self.get_column_info(task) + return ExecutionContext( + project=task.project, + repo_config=self.repo_config, + offline_store=self.offline_store, + online_store=self.online_store, + entity_defs=entity_defs, + column_info=column_info, + entity_df=entity_df, + ) + + def get_column_info( + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + ) -> ColumnInfo: + join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( + task.feature_view, self.registry.list_entities(task.project) + ) + return ColumnInfo( + join_keys=join_keys, + feature_cols=feature_cols, + ts_col=ts_col, + created_ts_col=created_ts_col, + ) diff --git a/sdk/python/feast/infra/compute_engines/dag/context.py b/sdk/python/feast/infra/compute_engines/dag/context.py new file mode 100644 index 00000000000..8b170b67766 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/context.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union + +import pandas as pd + +from feast.entity import Entity +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.offline_stores.offline_store import OfflineStore +from 
feast.infra.online_stores.online_store import OnlineStore +from feast.repo_config import RepoConfig + + +@dataclass +class ColumnInfo: + join_keys: List[str] + feature_cols: List[str] + ts_col: str + created_ts_col: Optional[str] + + def __iter__(self): + yield self.join_keys + yield self.feature_cols + yield self.ts_col + yield self.created_ts_col + + +@dataclass +class ExecutionContext: + """ + ExecutionContext holds all runtime information required to execute a DAG plan + within a ComputeEngine. It is passed into each DAGNode during execution and + contains shared context such as configuration, registry-backed entities, runtime + data (e.g. entity_df), and DAG evaluation state. + + Attributes: + project: Feast project name (namespace for features, entities, views). + + repo_config: Resolved RepoConfig containing provider and store configuration. + + offline_store: Reference to the configured OfflineStore implementation. + Used for loading raw feature data during materialization or retrieval. + + online_store: Reference to the OnlineStore implementation. + Used during materialization to write online features. + + entity_defs: List of Entity definitions fetched from the registry. + Used for resolving join keys, inferring timestamp columns, and + validating FeatureViews against schema. + + entity_df: A runtime DataFrame of entity rows used during historical + retrieval (e.g. for point-in-time join). Includes entity keys and + event timestamps. This is not part of the registry and is user-supplied + for training dataset generation. + + node_outputs: Internal cache of DAGValue outputs keyed by DAGNode name. + Automatically populated during ExecutionPlan execution to avoid redundant + computation. Used by downstream nodes to access their input data. + """ + + project: str + repo_config: RepoConfig + offline_store: OfflineStore + online_store: OnlineStore + column_info: ColumnInfo + entity_defs: List[Entity] + entity_df: Union[pd.DataFrame, None] = None + node_outputs: Dict[str, DAGValue] = field(default_factory=dict) diff --git a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py new file mode 100644 index 00000000000..f77fdd0b6c9 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class DAGFormat(str, Enum): + SPARK = "spark" + PANDAS = "pandas" + ARROW = "arrow" diff --git a/sdk/python/feast/infra/compute_engines/dag/node.py b/sdk/python/feast/infra/compute_engines/dag/node.py new file mode 100644 index 00000000000..033ae8f1780 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/node.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod +from typing import List + +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.value import DAGValue + + +class DAGNode(ABC): + name: str + inputs: List["DAGNode"] + outputs: List["DAGNode"] + + def __init__(self, name: str): + self.name = name + self.inputs = [] + self.outputs = [] + + def add_input(self, node: "DAGNode"): + if node in self.inputs: + raise ValueError(f"Input node {node.name} already added to {self.name}") + self.inputs.append(node) + node.outputs.append(self) + + def get_input_values(self, context: ExecutionContext) -> List[DAGValue]: + input_values = [] + for input_node in self.inputs: + if input_node.name not in context.node_outputs: + raise KeyError( + f"Missing output for input node '{input_node.name}' in context." 
+ ) + input_values.append(context.node_outputs[input_node.name]) + return input_values + + def get_single_input_value(self, context: ExecutionContext) -> DAGValue: + if len(self.inputs) != 1: + raise RuntimeError( + f"DAGNode '{self.name}' expected exactly 1 input, but got {len(self.inputs)}." + ) + input_node = self.inputs[0] + if input_node.name not in context.node_outputs: + raise KeyError( + f"Missing output for input node '{input_node.name}' in context." + ) + return context.node_outputs[input_node.name] + + @abstractmethod + def execute(self, context: ExecutionContext) -> DAGValue: ... diff --git a/sdk/python/feast/infra/compute_engines/dag/plan.py b/sdk/python/feast/infra/compute_engines/dag/plan.py new file mode 100644 index 00000000000..130a894bda8 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/plan.py @@ -0,0 +1,64 @@ +from typing import List + +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue + + +class ExecutionPlan: + """ + ExecutionPlan represents an ordered sequence of DAGNodes that together define + a data processing pipeline for feature materialization or historical retrieval. + + This plan is constructed as a topological sort of the DAG β€” meaning that each + node appears after all its input dependencies. The plan is executed in order, + caching intermediate results (`DAGValue`) so that each node can reuse outputs + from upstream nodes without recomputation. + + Key Concepts: + - DAGNode: Each node performs a specific logical step (e.g., read, aggregate, join). + - DAGValue: Output of a node, includes data (e.g., Spark DataFrame) and metadata. + - ExecutionContext: Contains runtime information (config, registry, stores, entity_df). + - node_outputs: A cache of intermediate results keyed by node name. + + Usage: + plan = ExecutionPlan(dag_nodes) + result = plan.execute(context) + + This design enables modular compute backends (e.g., Spark, Pandas, Arrow), where + each node defines its execution logic independently while benefiting from shared + execution orchestration, caching, and context injection. + + Example: + DAG: + ReadNode -> AggregateNode -> JoinNode -> TransformNode -> WriteNode + + Execution proceeds step by step, passing intermediate DAGValues through + the plan while respecting node dependencies and formats. + + This approach is inspired by execution DAGs in systems like Apache Spark, + Apache Beam, and Dask β€” but specialized for Feast’s feature computation domain. + """ + + def __init__(self, nodes: List[DAGNode]): + self.nodes = nodes + + def execute(self, context: ExecutionContext) -> DAGValue: + context.node_outputs = {} + + for node in self.nodes: + for input_node in node.inputs: + if input_node.name not in context.node_outputs: + context.node_outputs[input_node.name] = input_node.execute(context) + + output = node.execute(context) + context.node_outputs[node.name] = output + + return context.node_outputs[self.nodes[-1].name] + + def to_sql(self, context: ExecutionContext) -> str: + """ + Generate SQL query for the entire execution plan. + This is a placeholder and should be implemented in subclasses. 
+ """ + raise NotImplementedError("SQL generation is not implemented yet.") diff --git a/sdk/python/feast/infra/compute_engines/dag/value.py b/sdk/python/feast/infra/compute_engines/dag/value.py new file mode 100644 index 00000000000..0e2063d0dba --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/value.py @@ -0,0 +1,14 @@ +from typing import Any, Optional + +from feast.infra.compute_engines.dag.model import DAGFormat + + +class DAGValue: + def __init__(self, data: Any, format: DAGFormat, metadata: Optional[dict] = None): + self.data = data + self.format = format + self.metadata = metadata or {} + + def assert_format(self, expected: DAGFormat): + if self.format != expected: + raise ValueError(f"Expected format {expected}, but got {self.format}") diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py new file mode 100644 index 00000000000..cab32d47d26 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod +from typing import Union + +from feast import BatchFeatureView, FeatureView, StreamFeatureView +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import MaterializationTask + + +class FeatureBuilder(ABC): + """ + Translates a FeatureView definition and execution task into an execution DAG. + This builder is engine-specific and returns an ExecutionPlan that ComputeEngine can run. + """ + + def __init__( + self, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + task: Union[MaterializationTask, HistoricalRetrievalTask], + ): + self.feature_view = feature_view + self.task = task + self.nodes: list[DAGNode] = [] + + @abstractmethod + def build_source_node(self): + raise NotImplementedError + + @abstractmethod + def build_aggregation_node(self, input_node): + raise NotImplementedError + + @abstractmethod + def build_join_node(self, input_node): + raise NotImplementedError + + @abstractmethod + def build_filter_node(self, input_node): + raise NotImplementedError + + @abstractmethod + def build_dedup_node(self, input_node): + raise NotImplementedError + + @abstractmethod + def build_transformation_node(self, input_node): + raise NotImplementedError + + @abstractmethod + def build_output_nodes(self, input_node): + raise NotImplementedError + + @abstractmethod + def build_validation_node(self, input_node): + raise + + def _should_aggregate(self): + return ( + hasattr(self.feature_view, "aggregations") + and self.feature_view.aggregations is not None + and len(self.feature_view.aggregations) > 0 + ) + + def _should_transform(self): + return ( + hasattr(self.feature_view, "feature_transformation") + and self.feature_view.feature_transformation + ) + + def _should_validate(self): + return getattr(self.feature_view, "enable_validation", False) + + def build(self) -> ExecutionPlan: + last_node = self.build_source_node() + + # PIT join entities to the feature data, and perform filtering + if isinstance(self.task, HistoricalRetrievalTask): + last_node = self.build_join_node(last_node) + + last_node = self.build_filter_node(last_node) + + if self._should_aggregate(): + last_node = self.build_aggregation_node(last_node) + elif isinstance(self.task, HistoricalRetrievalTask): + last_node = self.build_dedup_node(last_node) 
+ + if self._should_transform(): + last_node = self.build_transformation_node(last_node) + + if self._should_validate(): + last_node = self.build_validation_node(last_node) + + self.build_output_nodes(last_node) + return ExecutionPlan(self.nodes) diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py new file mode 100644 index 00000000000..e6e6cc52971 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -0,0 +1,94 @@ +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob +from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import ( + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( + SparkMaterializationJob, +) +from feast.infra.offline_stores.offline_store import RetrievalJob + + +class SparkComputeEngine(ComputeEngine): + def __init__( + self, + offline_store, + online_store, + registry, + repo_config, + **kwargs, + ): + super().__init__( + offline_store=offline_store, + online_store=online_store, + registry=registry, + repo_config=repo_config, + **kwargs, + ) + self.spark_session = get_or_create_new_spark_session() + + def materialize(self, task: MaterializationTask) -> MaterializationJob: + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" + + # βœ… 1. Build typed execution context + context = self.get_execution_context(task) + + try: + # βœ… 2. Construct Feature Builder and run it + builder = SparkFeatureBuilder( + spark_session=self.spark_session, + feature_view=task.feature_view, + task=task, + ) + plan = builder.build() + plan.execute(context) + + # βœ… 3. Report success + return SparkMaterializationJob( + job_id=job_id, status=MaterializationJobStatus.SUCCEEDED + ) + + except Exception as e: + # πŸ›‘ Handle failure + return SparkMaterializationJob( + job_id=job_id, status=MaterializationJobStatus.ERROR, error=e + ) + + def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob: + if isinstance(task.entity_df, str): + raise NotImplementedError("SQL-based entity_df is not yet supported in DAG") + + # βœ… 1. Build typed execution context + context = self.get_execution_context(task) + + try: + # βœ… 2. 
Construct Feature Builder and run it + builder = SparkFeatureBuilder( + spark_session=self.spark_session, + feature_view=task.feature_view, + task=task, + ) + plan = builder.build() + + return SparkDAGRetrievalJob( + plan=plan, + spark_session=self.spark_session, + context=context, + config=self.repo_config, + full_feature_names=task.full_feature_name, + ) + except Exception as e: + # πŸ›‘ Handle failure + return SparkDAGRetrievalJob( + plan=None, + spark_session=self.spark_session, + context=context, + config=self.repo_config, + full_feature_names=task.full_feature_name, + error=e, + ) diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py new file mode 100644 index 00000000000..e7efbfe1195 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -0,0 +1,89 @@ +from typing import Union + +from pyspark.sql import SparkSession + +from feast import BatchFeatureView, FeatureView, StreamFeatureView +from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.compute_engines.feature_builder import FeatureBuilder +from feast.infra.compute_engines.spark.node import ( + SparkAggregationNode, + SparkDedupNode, + SparkFilterNode, + SparkHistoricalRetrievalReadNode, + SparkJoinNode, + SparkMaterializationReadNode, + SparkTransformationNode, + SparkWriteNode, +) +from feast.infra.materialization.batch_materialization_engine import MaterializationTask + + +class SparkFeatureBuilder(FeatureBuilder): + def __init__( + self, + spark_session: SparkSession, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + task: Union[MaterializationTask, HistoricalRetrievalTask], + ): + super().__init__(feature_view, task) + self.spark_session = spark_session + + def build_source_node(self): + if isinstance(self.task, MaterializationTask): + node = SparkMaterializationReadNode("source", self.task) + else: + node = SparkHistoricalRetrievalReadNode( + "source", self.task, self.spark_session + ) + self.nodes.append(node) + return node + + def build_aggregation_node(self, input_node): + agg_specs = self.feature_view.aggregations + group_by_keys = self.feature_view.entities + timestamp_col = self.feature_view.batch_source.timestamp_field + node = SparkAggregationNode( + "agg", input_node, agg_specs, group_by_keys, timestamp_col + ) + self.nodes.append(node) + return node + + def build_join_node(self, input_node): + join_keys = self.feature_view.entities + node = SparkJoinNode( + "join", input_node, join_keys, self.feature_view, self.spark_session + ) + self.nodes.append(node) + return node + + def build_filter_node(self, input_node): + filter_expr = None + if hasattr(self.feature_view, "filter"): + filter_expr = self.feature_view.filter + node = SparkFilterNode( + "filter", self.spark_session, input_node, self.feature_view, filter_expr + ) + self.nodes.append(node) + return node + + def build_dedup_node(self, input_node): + node = SparkDedupNode( + "dedup", input_node, self.feature_view, self.spark_session + ) + self.nodes.append(node) + return node + + def build_transformation_node(self, input_node): + udf_name = self.feature_view.feature_transformation.name + udf = self.feature_view.feature_transformation.udf + node = SparkTransformationNode(udf_name, input_node, udf) + self.nodes.append(node) + return node + + def build_output_nodes(self, input_node): + node = SparkWriteNode("output", input_node, self.feature_view) + self.nodes.append(node) + return node + + 
def build_validation_node(self, input_node): + pass diff --git a/sdk/python/feast/infra/compute_engines/spark/job.py b/sdk/python/feast/infra/compute_engines/spark/job.py new file mode 100644 index 00000000000..0f343789d96 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/job.py @@ -0,0 +1,56 @@ +from typing import List, Optional + +import pyspark +from pyspark.sql import SparkSession + +from feast import OnDemandFeatureView, RepoConfig +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkRetrievalJob, +) +from feast.infra.offline_stores.offline_store import RetrievalMetadata + + +class SparkDAGRetrievalJob(SparkRetrievalJob): + def __init__( + self, + spark_session: SparkSession, + plan: Optional[ExecutionPlan], + context: ExecutionContext, + full_feature_names: bool, + config: RepoConfig, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, + error: Optional[BaseException] = None, + ): + super().__init__( + spark_session=spark_session, + query="", + full_feature_names=full_feature_names, + config=config, + on_demand_feature_views=on_demand_feature_views, + metadata=metadata, + ) + self._plan = plan + self._context = context + self._metadata = metadata + self._spark_df = None + self._error = error + + def error(self) -> Optional[BaseException]: + return self._error + + def _ensure_executed(self): + if self._spark_df is None: + result = self._plan.execute(self._context) + self._spark_df = result.data + + def to_spark_df(self) -> pyspark.sql.DataFrame: + self._ensure_executed() + assert self._spark_df is not None, "Execution plan did not produce a DataFrame" + return self._spark_df + + def to_sql(self) -> str: + assert self._plan is not None, "Execution plan is not set" + return self._plan.to_sql(self._context) diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py new file mode 100644 index 00000000000..e3f737a4fa6 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -0,0 +1,414 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, List, Optional, Union, cast + +from pyspark.sql import DataFrame, SparkSession, Window +from pyspark.sql import functions as F + +from feast import BatchFeatureView, StreamFeatureView +from feast.aggregation import Aggregation +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import MaterializationTask +from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( + _map_by_partition, + _SparkSerializedArtifacts, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkRetrievalJob, + _get_entity_schema, +) +from feast.infra.offline_stores.offline_utils import ( + infer_event_timestamp_from_entity_df, +) +from feast.utils import _get_fields_with_aliases + +ENTITY_TS_ALIAS = "__entity_event_timestamp" + + +# Rename entity_df event_timestamp_col to match feature_df +def rename_entity_ts_column( + spark_session: 
SparkSession, entity_df: DataFrame +) -> DataFrame: + # check if entity_ts_alias already exists + if ENTITY_TS_ALIAS in entity_df.columns: + return entity_df + + entity_schema = _get_entity_schema( + spark_session=spark_session, + entity_df=entity_df, + ) + event_timestamp_col = infer_event_timestamp_from_entity_df( + entity_schema=entity_schema, + ) + if not isinstance(entity_df, DataFrame): + entity_df = spark_session.createDataFrame(entity_df) + entity_df = entity_df.withColumnRenamed(event_timestamp_col, ENTITY_TS_ALIAS) + return entity_df + + +@dataclass +class SparkJoinContext: + name: str # feature view name or alias + join_keys: List[str] + feature_columns: List[str] + timestamp_field: str + created_timestamp_column: Optional[str] + ttl_seconds: Optional[int] + min_event_timestamp: Optional[datetime] + max_event_timestamp: Optional[datetime] + field_mapping: Dict[str, str] # original_column_name -> renamed_column + full_feature_names: bool = False # apply feature view name prefix + + +class SparkMaterializationReadNode(DAGNode): + def __init__( + self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask] + ): + super().__init__(name) + self.task = task + + def execute(self, context: ExecutionContext) -> DAGValue: + offline_store = context.offline_store + start_time = self.task.start_time + end_time = self.task.end_time + + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = context.column_info + + # πŸ“₯ Reuse Feast's robust query resolver + retrieval_job = offline_store.pull_latest_from_table_or_query( + config=context.repo_config, + data_source=self.task.feature_view.batch_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_time, + end_date=end_time, + ) + spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() + + return DAGValue( + data=spark_df, + format=DAGFormat.SPARK, + metadata={ + "source": "feature_view_batch_source", + "timestamp_field": timestamp_field, + "created_timestamp_column": created_timestamp_column, + "start_date": start_time, + "end_date": end_time, + }, + ) + + +class SparkHistoricalRetrievalReadNode(DAGNode): + def __init__( + self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession + ): + super().__init__(name) + self.task = task + self.spark_session = spark_session + + def execute(self, context: ExecutionContext) -> DAGValue: + """ + Read data from the offline store on the Spark engine. + TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features. 
+ Args: + context: SparkExecutionContext + Returns: DAGValue + """ + fv = self.task.feature_view + source = fv.batch_source + + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = context.column_info + + # TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp + # retrieval_job = offline_store.pull_all_from_table_or_query( + # config=context.repo_config, + # data_source=source, + # join_key_columns=join_key_columns, + # feature_name_columns=feature_name_columns, + # timestamp_field=timestamp_field, + # start_date=min_ts, + # end_date=max_ts, + # ) + # spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() + + columns = join_key_columns + feature_name_columns + [timestamp_field] + if created_timestamp_column: + columns.append(created_timestamp_column) + + (fields_with_aliases, aliases) = _get_fields_with_aliases( + fields=columns, + field_mappings=source.field_mapping, + ) + fields_with_alias_string = ", ".join(fields_with_aliases) + + from_expression = source.get_table_query_string() + + query = f""" + SELECT {fields_with_alias_string} + FROM {from_expression} + """ + spark_df = self.spark_session.sql(query) + + return DAGValue( + data=spark_df, + format=DAGFormat.SPARK, + metadata={ + "source": "feature_view_batch_source", + "timestamp_field": timestamp_field, + }, + ) + + +class SparkAggregationNode(DAGNode): + def __init__( + self, + name: str, + input_node: DAGNode, + aggregations: List[Aggregation], + group_by_keys: List[str], + timestamp_col: str, + ): + super().__init__(name) + self.add_input(input_node) + self.aggregations = aggregations + self.group_by_keys = group_by_keys + self.timestamp_col = timestamp_col + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.SPARK) + input_df: DataFrame = input_value.data + + agg_exprs = [] + for agg in self.aggregations: + func = getattr(F, agg.function) + expr = func(agg.column).alias( + f"{agg.function}_{agg.column}_{int(agg.time_window.total_seconds())}s" + if agg.time_window + else f"{agg.function}_{agg.column}" + ) + agg_exprs.append(expr) + + if any(agg.time_window for agg in self.aggregations): + # πŸ•’ Use Spark's `window` function + time_window = self.aggregations[ + 0 + ].time_window # assume consistent window size for now + if time_window is None: + raise ValueError("Aggregation requires time_window but got None.") + window_duration_str = f"{int(time_window.total_seconds())} seconds" + + grouped = input_df.groupBy( + *self.group_by_keys, + F.window(F.col(self.timestamp_col), window_duration_str), + ).agg(*agg_exprs) + else: + # Simple aggregation + grouped = input_df.groupBy( + *self.group_by_keys, + ).agg(*agg_exprs) + + return DAGValue( + data=grouped, format=DAGFormat.SPARK, metadata={"aggregated": True} + ) + + +class SparkJoinNode(DAGNode): + def __init__( + self, + name: str, + feature_node: DAGNode, + join_keys: List[str], + feature_view: Union[BatchFeatureView, StreamFeatureView], + spark_session: SparkSession, + ): + super().__init__(name) + self.join_keys = join_keys + self.add_input(feature_node) + self.feature_view = feature_view + self.spark_session = spark_session + + def execute(self, context: ExecutionContext) -> DAGValue: + feature_value = self.get_single_input_value(context) + feature_value.assert_format(DAGFormat.SPARK) + feature_df: DataFrame = feature_value.data + + entity_df = context.entity_df + assert entity_df is not None, 
"entity_df must be set in ExecutionContext" + + # Get timestamp fields from feature view + join_keys, feature_cols, ts_col, created_ts_col = context.column_info + + # Rename entity_df event_timestamp_col to match feature_df + entity_df = rename_entity_ts_column( + spark_session=self.spark_session, + entity_df=entity_df, + ) + + # Perform left join on entity df + joined = feature_df.join(entity_df, on=join_keys, how="left") + + return DAGValue( + data=joined, format=DAGFormat.SPARK, metadata={"joined_on": join_keys} + ) + + +class SparkFilterNode(DAGNode): + def __init__( + self, + name: str, + spark_session: SparkSession, + input_node: DAGNode, + feature_view: Union[BatchFeatureView, StreamFeatureView], + filter_condition: Optional[str] = None, + ): + super().__init__(name) + self.spark_session = spark_session + self.feature_view = feature_view + self.add_input(input_node) + self.filter_condition = filter_condition + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.SPARK) + input_df: DataFrame = input_value.data + + # Get timestamp fields from feature view + _, _, ts_col, _ = context.column_info + + # Optional filter: feature.ts <= entity.event_timestamp + filtered_df = input_df + if ENTITY_TS_ALIAS in input_df.columns: + filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS)) + + # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl + if self.feature_view.ttl: + ttl_seconds = int(self.feature_view.ttl.total_seconds()) + lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr( + f"INTERVAL {ttl_seconds} seconds" + ) + filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound) + + # Optional custom filter condition + if self.filter_condition: + filtered_df = filtered_df.filter(self.filter_condition) + + return DAGValue( + data=filtered_df, + format=DAGFormat.SPARK, + metadata={"filter_applied": True}, + ) + + +class SparkDedupNode(DAGNode): + def __init__( + self, + name: str, + input_node: DAGNode, + feature_view: Union[BatchFeatureView, StreamFeatureView], + spark_session: SparkSession, + ): + super().__init__(name) + self.add_input(input_node) + self.feature_view = feature_view + self.spark_session = spark_session + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.SPARK) + input_df: DataFrame = input_value.data + + # Get timestamp fields from feature view + join_keys, _, ts_col, created_ts_col = context.column_info + + # Dedup based on join keys and event timestamp + # Dedup with row_number + partition_cols = join_keys + [ENTITY_TS_ALIAS] + ordering = [F.col(ts_col).desc()] + if created_ts_col: + ordering.append(F.col(created_ts_col).desc()) + + window = Window.partitionBy(*partition_cols).orderBy(*ordering) + deduped_df = ( + input_df.withColumn("row_num", F.row_number().over(window)) + .filter("row_num = 1") + .drop("row_num") + ) + + return DAGValue( + data=deduped_df, + format=DAGFormat.SPARK, + metadata={"deduped": True}, + ) + + +class SparkWriteNode(DAGNode): + def __init__( + self, + name: str, + input_node: DAGNode, + feature_view: Union[BatchFeatureView, StreamFeatureView], + ): + super().__init__(name) + self.add_input(input_node) + self.feature_view = feature_view + + def execute(self, context: ExecutionContext) -> DAGValue: + spark_df: DataFrame = self.get_single_input_value(context).data + spark_serialized_artifacts = 
_SparkSerializedArtifacts.serialize( + feature_view=self.feature_view, repo_config=context.repo_config + ) + + # βœ… 1. Write to offline store (if enabled) + if self.feature_view.offline: + # TODO: Update _map_by_partition to be able to write to offline store + pass + + # βœ… 2. Write to online store (if enabled) + if self.feature_view.online: + spark_df.mapInPandas( + lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + ).count() + + return DAGValue( + data=spark_df, + format=DAGFormat.SPARK, + metadata={ + "feature_view": self.feature_view.name, + "write_to_online": self.feature_view.online, + "write_to_offline": self.feature_view.offline, + }, + ) + + +class SparkTransformationNode(DAGNode): + def __init__(self, name: str, input_node: DAGNode, udf): + super().__init__(name) + self.add_input(input_node) + self.udf = udf + + def execute(self, context: ExecutionContext) -> DAGValue: + input_val = self.get_single_input_value(context) + input_val.assert_format(DAGFormat.SPARK) + + transformed_df = self.udf(input_val.data) + + return DAGValue( + data=transformed_df, format=DAGFormat.SPARK, metadata={"transformed": True} + ) diff --git a/sdk/python/feast/infra/compute_engines/tasks.py b/sdk/python/feast/infra/compute_engines/tasks.py new file mode 100644 index 00000000000..a5b5583b3ce --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/tasks.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Union + +import pandas as pd + +from feast import BatchFeatureView, StreamFeatureView +from feast.infra.registry.registry import Registry + + +@dataclass +class HistoricalRetrievalTask: + project: str + entity_df: Union[pd.DataFrame, str] + feature_view: Union[BatchFeatureView, StreamFeatureView] + full_feature_name: bool + registry: Registry + start_time: datetime + end_time: datetime diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 2f134001a5a..e3608b10354 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -57,7 +57,8 @@ class StreamFeatureView(FeatureView): aggregations: List of aggregations registered with the stream feature view. mode: The mode of execution. timestamp_field: Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows. - online: A boolean indicating whether online retrieval is enabled for this feature view. + online: A boolean indicating whether online retrieval, and write to online store is enabled for this feature view. + offline: A boolean indicating whether offline retrieval, and write to offline store is enabled for this feature view. description: A human-readable description. tags: A dictionary of key-value pairs to store arbitrary metadata. owner: The owner of the stream feature view, typically the email of the primary maintainer. 
@@ -72,6 +73,7 @@ class StreamFeatureView(FeatureView): entity_columns: List[Field] features: List[Field] online: bool + offline: bool description: str tags: Dict[str, str] owner: str @@ -82,6 +84,7 @@ class StreamFeatureView(FeatureView): udf: Optional[FunctionType] udf_string: Optional[str] feature_transformation: Optional[Transformation] + stream_engine: Optional[Field] def __init__( self, @@ -92,6 +95,7 @@ def __init__( ttl: timedelta = timedelta(days=0), tags: Optional[Dict[str, str]] = None, online: bool = True, + offline: bool = False, description: str = "", owner: str = "", schema: Optional[List[Field]] = None, @@ -101,6 +105,7 @@ def __init__( udf: Optional[FunctionType] = None, udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, + stream_engine: Optional[Field] = None, ): if not flags_helper.is_test(): warnings.warn( @@ -131,6 +136,7 @@ def __init__( self.feature_transformation = ( feature_transformation or self.get_feature_transformation() ) + self.stream_engine = stream_engine super().__init__( name=name, @@ -138,6 +144,7 @@ def __init__( ttl=ttl, tags=tags, online=online, + offline=offline, description=description, owner=owner, schema=schema, diff --git a/sdk/python/feast/transformation/base.py b/sdk/python/feast/transformation/base.py index b02be0a6708..8ff1925d0e0 100644 --- a/sdk/python/feast/transformation/base.py +++ b/sdk/python/feast/transformation/base.py @@ -84,7 +84,7 @@ def __init__( self.mode = mode self.udf = udf self.udf_string = udf_string - self.name = name + self.name = name or udf.__name__ self.tags = tags or {} self.description = description self.owner = owner diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 4cca1379ed3..1520c2f7dd5 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -271,6 +271,7 @@ def _convert_arrow_fv_to_proto( if isinstance(table, pyarrow.Table): table = table.to_batches()[0] + # TODO: This will break if the feature view has aggregations or transformations columns = [ (field.name, field.dtype.to_value_type()) for field in feature_view.features ] + list(join_keys.items()) diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py new file mode 100644 index 00000000000..b8046c12296 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -0,0 +1,252 @@ +from datetime import datetime, timedelta +from typing import cast +from unittest.mock import MagicMock + +import pandas as pd +import pytest +from pyspark.sql import DataFrame +from tqdm import tqdm + +from feast import BatchFeatureView, Entity, Field +from feast.aggregation import Aggregation +from feast.data_source import DataSource +from feast.infra.compute_engines.spark.compute import SparkComputeEngine +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkOfflineStore, +) +from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import ( + SparkDataSourceCreator, +) +from feast.types import Float32, Int32, Int64 +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from 
tests.integration.feature_repos.repo_configuration import ( + construct_test_environment, +) +from tests.integration.feature_repos.universal.online_store.redis import ( + RedisOnlineStoreCreator, +) + +now = datetime.now() +today = datetime.today() + +driver = Entity( + name="driver_id", + description="driver id", +) + + +def create_feature_dataset(spark_environment) -> DataSource: + yesterday = today - timedelta(days=1) + last_week = today - timedelta(days=7) + df = pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.5, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": last_week, + "created": now - timedelta(hours=3), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.7, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.3, + "acc_rate": 0.6, + "avg_daily_trips": 12, + }, + ] + ) + ds = spark_environment.data_source_creator.create_data_source( + df, + spark_environment.feature_store.project, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + return ds + + +def create_entity_df() -> pd.DataFrame: + entity_df = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": today}, + {"driver_id": 1002, "event_timestamp": today}, + ] + ) + return entity_df + + +def create_spark_environment(): + spark_config = IntegrationTestRepoConfig( + provider="local", + online_store_creator=RedisOnlineStoreCreator, + offline_store_creator=SparkDataSourceCreator, + batch_engine={"type": "spark.engine", "partitions": 10}, + ) + spark_environment = construct_test_environment( + spark_config, None, entity_key_serialization_version=2 + ) + spark_environment.setup() + return spark_environment + + +@pytest.mark.integration +def test_spark_compute_engine_get_historical_features(): + spark_environment = create_spark_environment() + fs = spark_environment.feature_store + registry = fs.registry + data_source = create_feature_dataset(spark_environment) + + def transform_feature(df: DataFrame) -> DataFrame: + df = df.withColumn("sum_conv_rate", df["sum_conv_rate"] * 2) + df = df.withColumn("avg_acc_rate", df["avg_acc_rate"] * 2) + return df + + driver_stats_fv = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="python", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + udf=transform_feature, + udf_string="transform_feature", + ttl=timedelta(days=3), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=False, + offline=False, + source=data_source, + ) + + entity_df = create_entity_df() + + try: + fs.apply([driver, driver_stats_fv]) + + # πŸ›  Build retrieval task + task = HistoricalRetrievalTask( + project=spark_environment.project, + entity_df=entity_df, + feature_view=driver_stats_fv, + full_feature_name=False, + registry=registry, + start_time=now - timedelta(days=1), + end_time=now, + ) + + # πŸ§ͺ Run SparkComputeEngine + engine = SparkComputeEngine( + repo_config=spark_environment.config, + offline_store=SparkOfflineStore(), + online_store=MagicMock(), + 
registry=registry, + ) + + spark_dag_retrieval_job = engine.get_historical_features(task) + spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df() + df_out = spark_df.orderBy("driver_id").to_pandas_on_spark() + + # βœ… Assert output + assert df_out.driver_id.to_list() == [1001, 1002] + assert abs(df_out["sum_conv_rate"].to_list()[0] - 1.6) < 1e-6 + assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[1] - 1.0) < 1e-6 + + finally: + spark_environment.teardown() + + +@pytest.mark.integration +def test_spark_compute_engine_materialize(): + spark_environment = create_spark_environment() + fs = spark_environment.feature_store + registry = fs.registry + + data_source = create_feature_dataset(spark_environment) + + def transform_feature(df: DataFrame) -> DataFrame: + df = df.withColumn("conv_rate", df["conv_rate"] * 2) + df = df.withColumn("acc_rate", df["acc_rate"] * 2) + return df + + driver_stats_fv = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="python", + udf=transform_feature, + udf_string="transform_feature", + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=False, + source=data_source, + ) + + def tqdm_builder(length): + return tqdm(total=length, ncols=100) + + try: + fs.apply([driver, driver_stats_fv]) + + # πŸ›  Build retrieval task + task = MaterializationTask( + project=spark_environment.project, + feature_view=driver_stats_fv, + start_time=now - timedelta(days=1), + end_time=now, + tqdm_builder=tqdm_builder, + ) + + # πŸ§ͺ Run SparkComputeEngine + engine = SparkComputeEngine( + repo_config=spark_environment.config, + offline_store=SparkOfflineStore(), + online_store=MagicMock(), + registry=registry, + ) + + spark_materialize_job = engine.materialize(task) + + assert spark_materialize_job.status() == MaterializationJobStatus.SUCCEEDED + finally: + spark_environment.teardown() + + +if __name__ == "__main__": + test_spark_compute_engine_get_historical_features() diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py new file mode 100644 index 00000000000..afeea82008a --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -0,0 +1,243 @@ +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest +from pyspark.sql import SparkSession + +from feast.aggregation import Aggregation +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.spark.node import ( + SparkAggregationNode, + SparkDedupNode, + SparkJoinNode, + SparkTransformationNode, +) +from tests.example_repos.example_feature_repo_with_bfvs import ( + driver, + driver_hourly_stats_view, +) + + +@pytest.fixture(scope="session") +def spark_session(): + spark = ( + SparkSession.builder.appName("FeastSparkTests") + .master("local[*]") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + ) + + yield spark + + spark.stop() + + +def test_spark_transformation_node_executes_udf(spark_session): + # Sample Spark input + df = spark_session.createDataFrame( + [ + {"name": "John 
D.", "age": 30}, + {"name": "Alice G.", "age": 25}, + ] + ) + + def strip_extra_spaces(df): + from pyspark.sql.functions import col, regexp_replace + + return df.withColumn("name", regexp_replace(col("name"), "\\s+", " ")) + + # Wrap DAGValue + input_value = DAGValue(data=df, format=DAGFormat.SPARK) + + # Setup context + context = ExecutionContext( + project="test_proj", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=MagicMock(), + entity_df=None, + column_info=ColumnInfo( + join_keys=["name"], + feature_cols=["age"], + ts_col="", + created_ts_col="", + ), + node_outputs={"source": input_value}, + ) + + # Create and run the node + node = SparkTransformationNode( + "transform", input_node=MagicMock(), udf=strip_extra_spaces + ) + + node.inputs[0].name = "source" + result = node.execute(context) + + # Assert output + out_df = result.data + rows = out_df.orderBy("age").collect() + assert rows[0]["name"] == "Alice G." + assert rows[1]["name"] == "John D." + + +def test_spark_aggregation_node_executes_correctly(spark_session): + # Sample input DataFrame + input_df = spark_session.createDataFrame( + [ + {"user_id": 1, "value": 10}, + {"user_id": 1, "value": 20}, + {"user_id": 2, "value": 5}, + ] + ) + + # Define Aggregation spec (e.g. COUNT on value) + agg_specs = [Aggregation(column="value", function="count")] + + # Wrap as DAGValue + input_value = DAGValue(data=input_df, format=DAGFormat.SPARK) + + # Setup context + context = ExecutionContext( + project="test_project", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=[], + entity_df=None, + column_info=ColumnInfo( + join_keys=["user_id"], + feature_cols=["value"], + ts_col="", + created_ts_col="", + ), + node_outputs={"source": input_value}, + ) + + # Create and configure node + node = SparkAggregationNode( + name="agg", + input_node=MagicMock(), + aggregations=agg_specs, + group_by_keys=["user_id"], + timestamp_col="", + ) + node.inputs[0].name = "source" + + # Execute + result = node.execute(context) + result_df = result.data.orderBy("user_id").collect() + + # Validate output + assert result.format == DAGFormat.SPARK + assert result_df[0]["user_id"] == 1 + assert result_df[0]["count_value"] == 2 + assert result_df[1]["user_id"] == 2 + assert result_df[1]["count_value"] == 1 + + +def test_spark_join_node_executes_point_in_time_join(spark_session): + now = datetime.utcnow() + + # Entity DataFrame (point-in-time join targets) + entity_df = spark_session.createDataFrame( + [ + {"driver_id": 1001, "event_timestamp": now}, + {"driver_id": 1002, "event_timestamp": now}, + ] + ) + + # Feature DataFrame (raw features with timestamp) + feature_df = spark_session.createDataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": now - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.95, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": now - timedelta(days=2), + "created": now - timedelta(hours=4), + "conv_rate": 0.75, + "acc_rate": 0.90, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": now - timedelta(days=1), + "created": now - timedelta(hours=3), + "conv_rate": 0.7, + "acc_rate": 0.88, + "avg_daily_trips": 12, + }, + ] + ) + + # Wrap as DAGValues + feature_val = DAGValue(data=feature_df, format=DAGFormat.SPARK) + + # Setup FeatureView mock with batch_source metadata + feature_view = driver_hourly_stats_view + + # Set up context + context = 
ExecutionContext( + project="test_project", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=[driver], + entity_df=entity_df, + node_outputs={ + "feature_node": feature_val, + }, + column_info=ColumnInfo( + join_keys=["driver_id"], + feature_cols=["conv_rate", "acc_rate", "avg_daily_trips"], + ts_col="event_timestamp", + created_ts_col="created", + ), + ) + + # Create the node and add input + join_node = SparkJoinNode( + name="join", + feature_node=MagicMock(name="feature_node"), + join_keys=["user_id"], + feature_view=feature_view, + spark_session=spark_session, + ) + join_node.inputs[0].name = "feature_node" # must match key in node_outputs + + # Execute the node + output = join_node.execute(context) + context.node_outputs["join"] = output + + dedup_node = SparkDedupNode( + name="dedup", + input_node=join_node, + feature_view=feature_view, + spark_session=spark_session, + ) + dedup_node.inputs[0].name = "join" # must match key in node_outputs + dedup_output = dedup_node.execute(context) + result_df = dedup_output.data.orderBy("driver_id").collect() + + # Assertions + assert output.format == DAGFormat.SPARK + assert len(result_df) == 2 + + # Validate result for driver_id = 1001 + assert result_df[0]["driver_id"] == 1001 + assert abs(result_df[0]["conv_rate"] - 0.8) < 1e-6 + assert result_df[0]["avg_daily_trips"] == 15 + + # Validate result for driver_id = 1002 + assert result_df[1]["driver_id"] == 1002 + assert abs(result_df[1]["conv_rate"] - 0.7) < 1e-6 + assert result_df[1]["avg_daily_trips"] == 12
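For reference, the point-in-time semantics exercised by the join, filter, and dedup nodes above can be restated as a small pandas sketch. This is illustrative only; the helper and column names are made up and it is not part of this change.

```python
import pandas as pd


def point_in_time_join(
    entity_df: pd.DataFrame,
    feature_df: pd.DataFrame,
    join_keys: list,
    ts_col: str,
    ttl: pd.Timedelta,
) -> pd.DataFrame:
    # Pair each entity row with every candidate feature row for the same keys.
    joined = entity_df.merge(feature_df, on=join_keys, suffixes=("", "_feature"))
    feature_ts = ts_col + "_feature"

    # Keep only feature rows observed at or before the entity event timestamp
    # and no older than the TTL window (mirrors SparkFilterNode).
    in_window = joined[
        (joined[feature_ts] <= joined[ts_col])
        & (joined[feature_ts] >= joined[ts_col] - ttl)
    ]

    # Take the latest remaining row per entity key + entity timestamp
    # (mirrors SparkDedupNode's row_number-over-window logic).
    return (
        in_window.sort_values(feature_ts)
        .groupby(join_keys + [ts_col], as_index=False)
        .last()
    )
```

Note that the Spark dedup node additionally orders by the created timestamp column, when present, to break ties between feature rows with the same event timestamp.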