diff --git a/protos/feast/core/FeatureView.proto b/protos/feast/core/FeatureView.proto index 481ae00403f..ff750acccc2 100644 --- a/protos/feast/core/FeatureView.proto +++ b/protos/feast/core/FeatureView.proto @@ -79,6 +79,8 @@ message FeatureViewSpec { // Whether these features should be written to the offline store bool offline = 13; + + repeated FeatureViewSpec source_views = 14; } message FeatureViewMeta { diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 933696ced33..85a71f01c43 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -53,6 +53,7 @@ class BatchFeatureView(FeatureView): entities: List[str] ttl: Optional[timedelta] source: DataSource + sink_source: Optional[DataSource] = None schema: List[Field] entity_columns: List[Field] features: List[Field] @@ -65,7 +66,7 @@ class BatchFeatureView(FeatureView): materialization_intervals: List[Tuple[datetime, datetime]] udf: Optional[Callable[[Any], Any]] udf_string: Optional[str] - feature_transformation: Transformation + feature_transformation: Optional[Transformation] batch_engine: Optional[Field] aggregations: Optional[List[Aggregation]] @@ -74,7 +75,8 @@ def __init__( *, name: str, mode: Union[TransformationMode, str] = TransformationMode.PYTHON, - source: DataSource, + source: Union[DataSource, "BatchFeatureView", List["BatchFeatureView"]], + sink_source: Optional[DataSource] = None, entities: Optional[List[Entity]] = None, ttl: Optional[timedelta] = None, tags: Optional[Dict[str, str]] = None, @@ -83,7 +85,7 @@ def __init__( description: str = "", owner: str = "", schema: Optional[List[Field]] = None, - udf: Optional[Callable[[Any], Any]], + udf: Optional[Callable[[Any], Any]] = None, udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, batch_engine: Optional[Field] = None, @@ -96,7 +98,7 @@ def __init__( RuntimeWarning, ) - if ( + if isinstance(source, DataSource) and ( type(source).__name__ not in SUPPORTED_BATCH_SOURCES and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE ): @@ -124,14 +126,13 @@ def __init__( description=description, owner=owner, schema=schema, - source=source, + source=source, # type: ignore[arg-type] + sink_source=sink_source, ) - def get_feature_transformation(self) -> Transformation: + def get_feature_transformation(self) -> Optional[Transformation]: if not self.udf: - raise ValueError( - "Either a UDF or a feature transformation must be provided for BatchFeatureView" - ) + return None if self.mode in ( TransformationMode.PANDAS, TransformationMode.PYTHON, diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 3df26d76f1b..255df3db41f 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -922,7 +922,7 @@ def apply( for fv in itertools.chain( views_to_update, sfvs_to_update, odfvs_with_writes_to_update ): - if isinstance(fv, FeatureView): + if isinstance(fv, FeatureView) and fv.batch_source: data_sources_set_to_update.add(fv.batch_source) if hasattr(fv, "stream_source"): if fv.stream_source: diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 2c2106f5a3e..16e786bced8 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -14,7 +14,7 @@ import copy import warnings from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type, Union from 
google.protobuf.duration_pb2 import Duration from google.protobuf.message import Message @@ -90,6 +90,7 @@ class FeatureView(BaseFeatureView): ttl: Optional[timedelta] batch_source: DataSource stream_source: Optional[DataSource] + source_views: Optional[List["FeatureView"]] entity_columns: List[Field] features: List[Field] online: bool @@ -103,7 +104,8 @@ def __init__( self, *, name: str, - source: DataSource, + source: Union[DataSource, "FeatureView", List["FeatureView"]], + sink_source: Optional[DataSource] = None, schema: Optional[List[Field]] = None, entities: Optional[List[Entity]] = None, ttl: Optional[timedelta] = timedelta(days=0), @@ -144,22 +146,45 @@ def __init__( self.ttl = ttl schema = schema or [] - # Initialize data sources. + # Normalize source + self.stream_source = None + self.data_source: Optional[DataSource] = None + self.source_views: List[FeatureView] = [] + + if isinstance(source, DataSource): + self.data_source = source + elif isinstance(source, FeatureView): + self.source_views = [source] + elif isinstance(source, list) and all( + isinstance(sv, FeatureView) for sv in source + ): + self.source_views = source + else: + raise TypeError( + "source must be a DataSource, a FeatureView, or a list of FeatureView." + ) + + # Set up stream, batch and derived view sources if ( - isinstance(source, PushSource) - or isinstance(source, KafkaSource) - or isinstance(source, KinesisSource) + isinstance(self.data_source, PushSource) + or isinstance(self.data_source, KafkaSource) + or isinstance(self.data_source, KinesisSource) ): - self.stream_source = source - if not source.batch_source: + # Stream source definition + self.stream_source = self.data_source + if not self.data_source.batch_source: raise ValueError( - f"A batch_source needs to be specified for stream source `{source.name}`" + f"A batch_source needs to be specified for stream source `{self.data_source.name}`" ) - else: - self.batch_source = source.batch_source + self.batch_source = self.data_source.batch_source + elif self.data_source: + # Batch source definition + self.batch_source = self.data_source else: - self.stream_source = None - self.batch_source = source + # Derived view source definition + if not sink_source: + raise ValueError("Derived FeatureView must specify `sink_source`.") + self.batch_source = sink_source # Initialize features and entity columns. features: List[Field] = [] @@ -201,17 +226,18 @@ def __init__( ) # TODO(felixwang9817): Add more robust validation of features. - cols = [field.name for field in schema] - for col in cols: - if ( - self.batch_source.field_mapping is not None - and col in self.batch_source.field_mapping.keys() - ): - raise ValueError( - f"The field {col} is mapped to {self.batch_source.field_mapping[col]} for this data source. " - f"Please either remove this field mapping or use {self.batch_source.field_mapping[col]} as the " - f"Entity or Feature name." - ) + if self.batch_source is not None: + cols = [field.name for field in schema] + for col in cols: + if ( + self.batch_source.field_mapping is not None + and col in self.batch_source.field_mapping.keys() + ): + raise ValueError( + f"The field {col} is mapped to {self.batch_source.field_mapping[col]} for this data source. " + f"Please either remove this field mapping or use {self.batch_source.field_mapping[col]} as the " + f"Entity or Feature name." 
+ ) super().__init__( name=name, @@ -219,7 +245,7 @@ def __init__( description=description, tags=tags, owner=owner, - source=source, + source=self.batch_source, ) self.online = online self.offline = offline @@ -348,13 +374,18 @@ def to_proto(self) -> FeatureViewProto: meta = self.to_proto_meta() ttl_duration = self.get_ttl_duration() - batch_source_proto = self.batch_source.to_proto() - batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}" + batch_source_proto = None + if self.batch_source: + batch_source_proto = self.batch_source.to_proto() + batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}" stream_source_proto = None if self.stream_source: stream_source_proto = self.stream_source.to_proto() stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}" + source_view_protos = None + if self.source_views: + source_view_protos = [view.to_proto().spec for view in self.source_views] spec = FeatureViewSpecProto( name=self.name, entities=self.entities, @@ -368,6 +399,7 @@ def to_proto(self) -> FeatureViewProto: offline=self.offline, batch_source=batch_source_proto, stream_source=stream_source_proto, + source_views=source_view_protos, ) return FeatureViewProto(spec=spec, meta=meta) @@ -403,12 +435,21 @@ def from_proto(cls, feature_view_proto: FeatureViewProto): Returns: A FeatureViewProto object based on the feature view protobuf. """ - batch_source = DataSource.from_proto(feature_view_proto.spec.batch_source) + batch_source = ( + DataSource.from_proto(feature_view_proto.spec.batch_source) + if feature_view_proto.spec.HasField("batch_source") + else None + ) stream_source = ( DataSource.from_proto(feature_view_proto.spec.stream_source) if feature_view_proto.spec.HasField("stream_source") else None ) + source_views = [ + FeatureView.from_proto(FeatureViewProto(spec=view_spec, meta=None)) + for view_spec in feature_view_proto.spec.source_views + ] + feature_view = cls( name=feature_view_proto.spec.name, description=feature_view_proto.spec.description, @@ -421,7 +462,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto): if feature_view_proto.spec.ttl.ToNanoseconds() == 0 else feature_view_proto.spec.ttl.ToTimedelta() ), - source=batch_source, + source=batch_source if batch_source else source_views, ) if stream_source: feature_view.stream_source = stream_source diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py index dd43d1f5bdb..f5f234b7301 100644 --- a/sdk/python/feast/inference.py +++ b/sdk/python/feast/inference.py @@ -26,6 +26,8 @@ def update_data_sources_with_inferred_event_timestamp_col( ) -> None: ERROR_MSG_PREFIX = "Unable to infer DataSource timestamp_field" for data_source in data_sources: + if data_source is None: + continue if isinstance(data_source, RequestSource): continue if isinstance(data_source, PushSource): diff --git a/sdk/python/feast/infra/compute_engines/algorithms/__init__.py b/sdk/python/feast/infra/compute_engines/algorithms/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/compute_engines/algorithms/topo.py b/sdk/python/feast/infra/compute_engines/algorithms/topo.py new file mode 100644 index 00000000000..8b2e7aebc13 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/algorithms/topo.py @@ -0,0 +1,43 @@ +from typing import List, Set + +from 
feast.infra.compute_engines.dag.node import DAGNode + + +def topo_sort(root: DAGNode) -> List[DAGNode]: + """ + Topologically sort a DAG starting from a single root node. + + Args: + root: The root DAGNode. + + Returns: + A list of DAGNodes in topological order (dependencies first). + """ + return topo_sort_multiple([root]) + + +def topo_sort_multiple(roots: List[DAGNode]) -> List[DAGNode]: + """ + Topologically sort a DAG with multiple roots. + + Args: + roots: List of root DAGNodes. + + Returns: + A list of all reachable DAGNodes in execution-safe order. + """ + visited: Set[int] = set() + ordered: List[DAGNode] = [] + + def dfs(node: DAGNode): + if id(node) in visited: + return + visited.add(id(node)) + for input_node in node.inputs: + dfs(input_node) + ordered.append(node) + + for root in roots: + dfs(root) + + return ordered diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index 6acdb8d11d6..e50494abd63 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Sequence, Union +from typing import List, Sequence, Union import pyarrow as pa @@ -12,13 +12,12 @@ MaterializationTask, ) from feast.infra.common.retrieval_task import HistoricalRetrievalTask -from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.stream_feature_view import StreamFeatureView -from feast.utils import _get_column_names class ComputeEngine(ABC): @@ -124,52 +123,11 @@ def get_execution_context( if hasattr(task, "entity_df") and task.entity_df is not None: entity_df = task.entity_df - column_info = self.get_column_info(registry, task) return ExecutionContext( project=task.project, repo_config=self.repo_config, offline_store=self.offline_store, online_store=self.online_store, entity_defs=entity_defs, - column_info=column_info, entity_df=entity_df, ) - - def get_column_info( - self, - registry: BaseRegistry, - task: Union[MaterializationTask, HistoricalRetrievalTask], - ) -> ColumnInfo: - entities = [] - for entity_name in task.feature_view.entities: - entities.append(registry.get_entity(entity_name, task.project)) - - join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( - task.feature_view, entities - ) - field_mapping = self.get_field_mapping(task.feature_view) - - return ColumnInfo( - join_keys=join_keys, - feature_cols=feature_cols, - ts_col=ts_col, - created_ts_col=created_ts_col, - field_mapping=field_mapping, - ) - - def get_field_mapping( - self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView] - ) -> Optional[dict]: - """ - Get the field mapping for a feature view. - Args: - feature_view: The feature view to get the field mapping for. - - Returns: - A dictionary mapping field names to column names. 
- """ - if feature_view.stream_source: - return feature_view.stream_source.field_mapping - if feature_view.batch_source: - return feature_view.batch_source.field_mapping - return None diff --git a/sdk/python/feast/infra/compute_engines/dag/context.py b/sdk/python/feast/infra/compute_engines/dag/context.py index 6b1970d25f8..46eda356223 100644 --- a/sdk/python/feast/infra/compute_engines/dag/context.py +++ b/sdk/python/feast/infra/compute_engines/dag/context.py @@ -82,15 +82,12 @@ class ExecutionContext: node_outputs: Internal cache of DAGValue outputs keyed by DAGNode name. Automatically populated during ExecutionPlan execution to avoid redundant computation. Used by downstream nodes to access their input data. - - field_mapping: A mapping of field names to their corresponding column names in the """ project: str repo_config: RepoConfig offline_store: OfflineStore online_store: OnlineStore - column_info: ColumnInfo entity_defs: List[Entity] entity_df: Union[pd.DataFrame, None] = None node_outputs: Dict[str, DAGValue] = field(default_factory=dict) diff --git a/sdk/python/feast/infra/compute_engines/dag/node.py b/sdk/python/feast/infra/compute_engines/dag/node.py index 033ae8f1780..9fb520e7c13 100644 --- a/sdk/python/feast/infra/compute_engines/dag/node.py +++ b/sdk/python/feast/infra/compute_engines/dag/node.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.value import DAGValue @@ -10,10 +10,13 @@ class DAGNode(ABC): inputs: List["DAGNode"] outputs: List["DAGNode"] - def __init__(self, name: str): + def __init__(self, name: str, inputs: Optional[List["DAGNode"]] = None): self.name = name - self.inputs = [] - self.outputs = [] + self.inputs: List["DAGNode"] = [] + self.outputs: List["DAGNode"] = [] + + for node in inputs or []: + self.add_input(node) def add_input(self, node: "DAGNode"): if node in self.inputs: diff --git a/sdk/python/feast/infra/compute_engines/dag/plan.py b/sdk/python/feast/infra/compute_engines/dag/plan.py index 130a894bda8..31db551e635 100644 --- a/sdk/python/feast/infra/compute_engines/dag/plan.py +++ b/sdk/python/feast/infra/compute_engines/dag/plan.py @@ -31,7 +31,8 @@ class ExecutionPlan: Example: DAG: - ReadNode -> AggregateNode -> JoinNode -> TransformNode -> WriteNode + ReadNode -> TransformNode -> AggregateNode -> -> WriteNode + -> JoinNode -> Execution proceeds step by step, passing intermediate DAGValues through the plan while respecting node dependencies and formats. @@ -47,10 +48,6 @@ def execute(self, context: ExecutionContext) -> DAGValue: context.node_outputs = {} for node in self.nodes: - for input_node in node.inputs: - if input_node.name not in context.node_outputs: - context.node_outputs[input_node.name] = input_node.execute(context) - output = node.execute(context) context.node_outputs[node.name] = output @@ -62,3 +59,17 @@ def to_sql(self, context: ExecutionContext) -> str: This is a placeholder and should be implemented in subclasses. """ raise NotImplementedError("SQL generation is not implemented yet.") + + def to_dag(self) -> str: + """ + Render the DAG as a multiline string with full node expansion (no visited shortcut). 
+ """ + + def walk(node: DAGNode, indent: int = 0) -> List[str]: + prefix = " " * indent + lines = [f"{prefix}- {node.name}"] + for input_node in node.inputs: + lines.extend(walk(input_node, indent + 1)) + return lines + + return "\n".join(walk(self.nodes[-1])) diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index 9d4e4466499..26f3703c794 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -1,10 +1,18 @@ from abc import ABC, abstractmethod -from typing import Union +from typing import Dict, List, Optional, Union +from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.infra.common.materialization_job import MaterializationTask from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.algorithms.topo import topo_sort +from feast.infra.compute_engines.dag.context import ColumnInfo from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.feature_resolver import ( + FeatureResolver, +) +from feast.infra.registry.base_registry import BaseRegistry +from feast.utils import _get_column_names class FeatureBuilder(ABC): @@ -15,84 +23,158 @@ class FeatureBuilder(ABC): def __init__( self, + registry: BaseRegistry, + feature_view, task: Union[MaterializationTask, HistoricalRetrievalTask], ): - self.feature_view = task.feature_view + self.registry = registry + self.feature_view = feature_view self.task = task - self.nodes: list[DAGNode] = [] + self.nodes: List[DAGNode] = [] + self.feature_resolver = FeatureResolver() + self.dag_root = self.feature_resolver.resolve(self.feature_view) @abstractmethod - def build_source_node(self): + def build_source_node(self, view): raise NotImplementedError @abstractmethod - def build_aggregation_node(self, input_node): + def build_aggregation_node(self, view, input_node): raise NotImplementedError @abstractmethod - def build_join_node(self, input_node): + def build_join_node(self, view, input_nodes): raise NotImplementedError @abstractmethod - def build_filter_node(self, input_node): + def build_filter_node(self, view, input_node): raise NotImplementedError @abstractmethod - def build_dedup_node(self, input_node): + def build_dedup_node(self, view, input_node): raise NotImplementedError @abstractmethod - def build_transformation_node(self, input_node): + def build_transformation_node(self, view, input_nodes): raise NotImplementedError @abstractmethod - def build_output_nodes(self, input_node): + def build_output_nodes(self, view, final_node): raise NotImplementedError @abstractmethod - def build_validation_node(self, input_node): - raise - - def _should_aggregate(self): - return ( - hasattr(self.feature_view, "aggregations") - and self.feature_view.aggregations is not None - and len(self.feature_view.aggregations) > 0 - ) + def build_validation_node(self, view, input_node): + raise NotImplementedError - def _should_transform(self): - return ( - hasattr(self.feature_view, "feature_transformation") - and self.feature_view.feature_transformation - ) + def _should_aggregate(self, view): + return bool(getattr(view, "aggregations", [])) - def _should_validate(self): - return getattr(self.feature_view, "enable_validation", False) + def _should_transform(self, view): + return bool(getattr(view, "feature_transformation", None)) - def 
_should_dedupe(self, task): - return isinstance(task, HistoricalRetrievalTask) or task.only_latest + def _should_validate(self, view): + return getattr(view, "enable_validation", False) - def build(self) -> ExecutionPlan: - last_node = self.build_source_node() + def _should_dedupe(self, view): + return isinstance(self.task, HistoricalRetrievalTask) or self.task.only_latest + + def _build(self, view, input_nodes: Optional[List[DAGNode]]) -> DAGNode: + # Step 1: build source node + if view.data_source: + last_node = self.build_source_node(view) + + if self._should_transform(view): + # Transform applied to the source data + last_node = self.build_transformation_node(view, [last_node]) + + # If there are input nodes, transform or join them + elif input_nodes: + # User-defined transform handles the merging of input views + if self._should_transform(view): + last_node = self.build_transformation_node(view, input_nodes) + # Default join + else: + last_node = self.build_join_node(view, input_nodes) + else: + raise ValueError(f"FeatureView {view.name} has no valid source or inputs") - # Join entity_df with source if needed - last_node = self.build_join_node(last_node) + # Step 2: filter + last_node = self.build_filter_node(view, last_node) - # PIT filter, TTL, and user-defined filter - last_node = self.build_filter_node(last_node) + # Step 3: aggregate or dedupe + if self._should_aggregate(view): + last_node = self.build_aggregation_node(view, last_node) + elif self._should_dedupe(view): + last_node = self.build_dedup_node(view, last_node) - if self._should_aggregate(): - last_node = self.build_aggregation_node(last_node) + # Step 4: validate + if self._should_validate(view): + last_node = self.build_validation_node(view, last_node) - # Dedupe only if not aggregated - elif self._should_dedupe(self.task): - last_node = self.build_dedup_node(last_node) + return last_node - if self._should_transform(): - last_node = self.build_transformation_node(last_node) + def build(self) -> ExecutionPlan: + # Step 1: Topo sort the FeatureViewNode DAG (Logical DAG) + logical_nodes = self.feature_resolver.topo_sort(self.dag_root) + + # Step 2: For each FeatureView, build its corresponding execution DAGNode + view_to_node: Dict[str, DAGNode] = {} + + for node in logical_nodes: + view = node.view + parent_dag_nodes = [ + view_to_node[parent.view.name] + for parent in node.inputs + if parent.view.name in view_to_node + ] + dag_node = self._build(view, parent_dag_nodes) + view_to_node[view.name] = dag_node + + # Step 3: Build output node + final_node = self.build_output_nodes( + self.feature_view, view_to_node[self.feature_view.name] + ) + + # Step 4: Topo sort the final DAG from the output node (Physical DAG) + sorted_nodes = topo_sort(final_node) + + # Step 5: Return sorted execution plan + return ExecutionPlan(sorted_nodes) - if self._should_validate(): - last_node = self.build_validation_node(last_node) + def get_column_info( + self, + view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + ) -> ColumnInfo: + entities = [] + for entity_name in view.entities: + entities.append(self.registry.get_entity(entity_name, self.task.project)) + + join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( + view, entities + ) + field_mapping = self.get_field_mapping(self.task.feature_view) + + return ColumnInfo( + join_keys=join_keys, + feature_cols=feature_cols, + ts_col=ts_col, + created_ts_col=created_ts_col, + field_mapping=field_mapping, + ) - self.build_output_nodes(last_node) - return 
ExecutionPlan(self.nodes) + def get_field_mapping( + self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView] + ) -> Optional[dict]: + """ + Get the field mapping for a feature view. + Args: + feature_view: The feature view to get the field mapping for. + + Returns: + A dictionary mapping field names to column names. + """ + if feature_view.stream_source: + return feature_view.stream_source.field_mapping + if feature_view.batch_source: + return feature_view.batch_source.field_mapping + return None diff --git a/sdk/python/feast/infra/compute_engines/feature_resolver.py b/sdk/python/feast/infra/compute_engines/feature_resolver.py new file mode 100644 index 00000000000..ae2f505c1d7 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/feature_resolver.py @@ -0,0 +1,95 @@ +from typing import List, Optional, Set + +from feast.feature_view import FeatureView +from feast.infra.compute_engines.algorithms.topo import topo_sort +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue + + +class FeatureViewNode(DAGNode): + """ + Logical representation of a node in the FeatureView dependency DAG. + """ + + def __init__( + self, view: FeatureView, inputs: Optional[List["FeatureViewNode"]] = None + ): + super().__init__(name=view.name) + self.view: FeatureView = view + self.inputs: List["FeatureViewNode"] = inputs or [] # type: ignore + + def execute(self, context: ExecutionContext) -> DAGValue: + raise NotImplementedError( + f"FeatureViewNode '{self.name}' does not implement execute method." + ) + + +class FeatureResolver: + """ + Resolves FeatureViews into a dependency graph (DAG) of FeatureViewNode objects. + This graph represents the logical dependencies between FeatureViews, allowing + for ordered execution and cycle detection. + """ + + def __init__(self): + self._visited: Set[str] = set() + self._resolution_path: List[str] = [] + self._node_cache: dict[str, FeatureViewNode] = {} + + def resolve(self, feature_view: FeatureView) -> FeatureViewNode: + """ + Entry point for resolving a FeatureView into a DAG node. + + Args: + feature_view: The root FeatureView to build the dependency graph from. + + Returns: + A FeatureViewNode representing the root of the logical dependency DAG. + """ + return self._walk(feature_view) + + def _walk(self, view: FeatureView): + """ + Recursive traversal of the FeatureView graph. + + If `source_views` is set on the FeatureView, parent nodes are created and added. + Cycles are detected using the current resolution path. + + Args: + view: The FeatureView to process. + """ + if view.name in self._resolution_path: + cycle = " → ".join(self._resolution_path + [view.name]) + raise ValueError(f"Cycle detected in FeatureView DAG: {cycle}") + + if view.name in self._node_cache: + return self._node_cache[view.name] + + node = FeatureViewNode(view) + self._node_cache[view.name] = node + + self._resolution_path.append(view.name) + if view.source_views: + for upstream_view in view.source_views: + input_node = self._walk(upstream_view) + node.inputs.append(input_node) + self._resolution_path.pop() + + return node + + def topo_sort(self, root: FeatureViewNode) -> List[FeatureViewNode]: + return topo_sort(root) # type: ignore + + def debug_dag(self, node: FeatureViewNode, depth=0): + """ + Prints the FeatureView dependency DAG for debugging. + + Args: + node: The root node to print from. 
+ depth: Internal argument used for recursive indentation. + """ + indent = " " * depth + print(f"{indent}- {node.view.name}") + for input_node in node.inputs: + self.debug_dag(input_node, depth + 1) # type: ignore diff --git a/sdk/python/feast/infra/compute_engines/local/compute.py b/sdk/python/feast/infra/compute_engines/local/compute.py index 0b99a58c304..341b20dee02 100644 --- a/sdk/python/feast/infra/compute_engines/local/compute.py +++ b/sdk/python/feast/infra/compute_engines/local/compute.py @@ -68,7 +68,7 @@ def _materialize_one( backend = self._get_backend(context) try: - builder = LocalFeatureBuilder(task, backend=backend) + builder = LocalFeatureBuilder(registry, task, backend=backend) plan = builder.build() plan.execute(context) return LocalMaterializationJob( @@ -90,7 +90,7 @@ def get_historical_features( backend = self._get_backend(context) try: - builder = LocalFeatureBuilder(task=task, backend=backend) + builder = LocalFeatureBuilder(registry, task=task, backend=backend) plan = builder.build() return LocalRetrievalJob( plan=plan, diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index e3e29099360..8cecaa431df 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -14,86 +14,89 @@ LocalTransformationNode, LocalValidationNode, ) +from feast.infra.registry.base_registry import BaseRegistry class LocalFeatureBuilder(FeatureBuilder): def __init__( self, + registry: BaseRegistry, task: Union[MaterializationTask, HistoricalRetrievalTask], backend: DataFrameBackend, ): - super().__init__(task) + super().__init__(registry, task.feature_view, task) self.backend = backend - def build_source_node(self): - source = self.feature_view.batch_source + def build_source_node(self, view): start_time = self.task.start_time end_time = self.task.end_time - node = LocalSourceReadNode("source", source, start_time, end_time) + column_info = self.get_column_info(view) + source = view.source + node = LocalSourceReadNode("source", source, column_info, start_time, end_time) self.nodes.append(node) return node - def build_join_node(self, input_node): - node = LocalJoinNode("join", self.backend) - node.add_input(input_node) + def build_join_node(self, view, input_nodes): + column_info = self.get_column_info(view) + node = LocalJoinNode("join", column_info, self.backend, inputs=input_nodes) self.nodes.append(node) return node - def build_filter_node(self, input_node): - filter_expr = None - if hasattr(self.feature_view, "filter"): - filter_expr = self.feature_view.filter - ttl = self.feature_view.ttl - node = LocalFilterNode("filter", self.backend, filter_expr, ttl) - node.add_input(input_node) + def build_filter_node(self, view, input_node): + filter_expr = getattr(view, "filter", None) + ttl = getattr(view, "ttl", None) + column_info = self.get_column_info(view) + node = LocalFilterNode( + "filter", column_info, self.backend, filter_expr, ttl, inputs=[input_node] + ) self.nodes.append(node) return node - @staticmethod - def _get_aggregate_operations(agg_specs): - agg_ops = {} - for agg in agg_specs: - if agg.time_window is not None: - raise ValueError( - "Time window aggregation is not supported in local compute engine. Please use a different compute " - "engine." 
- ) - alias = f"{agg.function}_{agg.column}" - agg_ops[alias] = (agg.function, agg.column) - return agg_ops - - def build_aggregation_node(self, input_node): - agg_specs = self.feature_view.aggregations + def build_aggregation_node(self, view, input_node): + agg_specs = view.aggregations agg_ops = self._get_aggregate_operations(agg_specs) - group_by_keys = self.feature_view.entities - node = LocalAggregationNode("agg", self.backend, group_by_keys, agg_ops) - node.add_input(input_node) + group_by_keys = view.entities + node = LocalAggregationNode( + "agg", self.backend, group_by_keys, agg_ops, inputs=[input_node] + ) self.nodes.append(node) return node - def build_dedup_node(self, input_node): - node = LocalDedupNode("dedup", self.backend) - node.add_input(input_node) + def build_dedup_node(self, view, input_node): + column_info = self.get_column_info(view) + node = LocalDedupNode("dedup", column_info, self.backend, inputs=[input_node]) self.nodes.append(node) return node - def build_transformation_node(self, input_node): + def build_transformation_node(self, view, input_nodes): + transform_config = view.feature_transformation node = LocalTransformationNode( - "transform", self.feature_view.feature_transformation, self.backend + "transform", transform_config, self.backend, inputs=input_nodes ) - node.add_input(input_node) self.nodes.append(node) return node - def build_validation_node(self, input_node): + def build_validation_node(self, view, input_node): + validation_config = view.validation_config node = LocalValidationNode( - "validate", self.feature_view.validation_config, self.backend + "validate", validation_config, self.backend, inputs=[input_node] ) - node.add_input(input_node) self.nodes.append(node) return node - def build_output_nodes(self, input_node): - node = LocalOutputNode("output", self.feature_view) - node.add_input(input_node) + def build_output_nodes(self, view, input_node): + node = LocalOutputNode("output", self.dag_root.view, inputs=[input_node]) self.nodes.append(node) + return node + + @staticmethod + def _get_aggregate_operations(agg_specs): + agg_ops = {} + for agg in agg_specs: + if agg.time_window is not None: + raise ValueError( + "Time window aggregation is not supported in the local compute engine." 
+ ) + alias = f"{agg.function}_{agg.column}" + agg_ops[alias] = (agg.function, agg.column) + return agg_ops diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index a8c4405dd06..870a098261d 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -1,11 +1,13 @@ from datetime import datetime, timedelta -from typing import Optional, Union +from typing import List, Optional, Union import pyarrow as pa from feast import BatchFeatureView, StreamFeatureView from feast.data_source import DataSource -from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue from feast.infra.compute_engines.local.backends.base import DataFrameBackend from feast.infra.compute_engines.local.local_node import LocalNode @@ -25,11 +27,13 @@ def __init__( self, name: str, source: DataSource, + column_info: ColumnInfo, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ): super().__init__(name) self.source = source + self.column_info = column_info self.start_time = start_time self.end_time = end_time @@ -39,62 +43,84 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: context=context, start_time=self.start_time, end_time=self.end_time, + column_info=self.column_info, ) arrow_table = retrieval_job.to_arrow() - field_mapping = context.column_info.field_mapping - if field_mapping: + if self.column_info.field_mapping: arrow_table = arrow_table.rename_columns( - [field_mapping.get(col, col) for col in arrow_table.column_names] + [ + self.column_info.field_mapping.get(col, col) + for col in arrow_table.column_names + ] ) return ArrowTableValue(data=arrow_table) class LocalJoinNode(LocalNode): - def __init__(self, name: str, backend: DataFrameBackend): - super().__init__(name) + def __init__( + self, + name: str, + column_info: ColumnInfo, + backend: DataFrameBackend, + inputs: Optional[List["DAGNode"]] = None, + how: str = "inner", + ): + super().__init__(name, inputs or []) + self.column_info = column_info self.backend = backend + self.how = how def execute(self, context: ExecutionContext) -> ArrowTableValue: - feature_table = self.get_single_table(context).data - - if context.entity_df is None: - output = ArrowTableValue(feature_table) - context.node_outputs[self.name] = output - return output + input_values = self.get_input_values(context) + for val in input_values: + val.assert_format(DAGFormat.ARROW) + + # Convert all upstream ArrowTables to backend DataFrames + joined_df = self.backend.from_arrow(input_values[0].data) + for val in input_values[1:]: + next_df = self.backend.from_arrow(val.data) + joined_df = self.backend.join( + joined_df, + next_df, + on=self.column_info.join_keys, + how=self.how, + ) - entity_table = pa.Table.from_pandas(context.entity_df) - feature_df = self.backend.from_arrow(feature_table) - entity_df = self.backend.from_arrow(entity_table) + # If entity_df is provided, join it in last + if context.entity_df is not None: + entity_df = self.backend.from_arrow(pa.Table.from_pandas(context.entity_df)) - entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) - entity_df_event_timestamp_col = infer_event_timestamp_from_entity_df( 
- entity_schema - ) + entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) + entity_ts_col = infer_event_timestamp_from_entity_df(entity_schema) - column_info = context.column_info + if entity_ts_col != ENTITY_TS_ALIAS: + entity_df = self.backend.rename_columns( + entity_df, {entity_ts_col: ENTITY_TS_ALIAS} + ) - entity_df = self.backend.rename_columns( - entity_df, {entity_df_event_timestamp_col: ENTITY_TS_ALIAS} - ) + joined_df = self.backend.join( + entity_df, + joined_df, + on=self.column_info.join_keys, + how="left", + ) - joined_df = self.backend.join( - feature_df, entity_df, on=column_info.join_keys, how="left" - ) result = self.backend.to_arrow(joined_df) - output = ArrowTableValue(result) - context.node_outputs[self.name] = output - return output + return ArrowTableValue(result) class LocalFilterNode(LocalNode): def __init__( self, name: str, + column_info: ColumnInfo, backend: DataFrameBackend, filter_expr: Optional[str] = None, ttl: Optional[timedelta] = None, + inputs=None, ): - super().__init__(name) + super().__init__(name, inputs=inputs) + self.column_info = column_info self.backend = backend self.filter_expr = filter_expr self.ttl = ttl @@ -103,7 +129,7 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data df = self.backend.from_arrow(input_table) - timestamp_column = context.column_info.timestamp_column + timestamp_column = self.column_info.timestamp_column if ENTITY_TS_ALIAS in self.backend.columns(df): # filter where feature.ts <= entity.event_timestamp @@ -128,9 +154,14 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalAggregationNode(LocalNode): def __init__( - self, name: str, backend: DataFrameBackend, group_keys: list[str], agg_ops: dict + self, + name: str, + backend: DataFrameBackend, + group_keys: list[str], + agg_ops: dict, + inputs=None, ): - super().__init__(name) + super().__init__(name, inputs=inputs) self.backend = backend self.group_keys = group_keys self.agg_ops = agg_ops @@ -141,13 +172,15 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: grouped_df = self.backend.groupby_agg(df, self.group_keys, self.agg_ops) result = self.backend.to_arrow(grouped_df) output = ArrowTableValue(result) - context.node_outputs[self.name] = output return output class LocalDedupNode(LocalNode): - def __init__(self, name: str, backend: DataFrameBackend): - super().__init__(name) + def __init__( + self, name: str, column_info: ColumnInfo, backend: DataFrameBackend, inputs=None + ): + super().__init__(name, inputs=inputs) + self.column_info = column_info self.backend = backend def execute(self, context: ExecutionContext) -> ArrowTableValue: @@ -155,17 +188,16 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: df = self.backend.from_arrow(input_table) # Extract join_keys, timestamp, and created_ts from context - column_info = context.column_info # Dedup strategy: sort and drop_duplicates - dedup_keys = context.column_info.join_keys + dedup_keys = self.column_info.join_keys if dedup_keys: - sort_keys = [column_info.timestamp_column] + sort_keys = [self.column_info.timestamp_column] if ( - column_info.created_timestamp_column - and column_info.created_timestamp_column in df.columns + self.column_info.created_timestamp_column + and self.column_info.created_timestamp_column in df.columns ): - sort_keys.append(column_info.created_timestamp_column) + sort_keys.append(self.column_info.created_timestamp_column) df = self.backend.drop_duplicates( df, 
keys=dedup_keys, sort_by=sort_keys, ascending=False ) @@ -177,8 +209,10 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalTransformationNode(LocalNode): - def __init__(self, name: str, transformation_fn, backend): - super().__init__(name) + def __init__( + self, name: str, transformation_fn, backend: DataFrameBackend, inputs=None + ): + super().__init__(name, inputs=inputs) self.transformation_fn = transformation_fn self.backend = backend @@ -193,8 +227,10 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalValidationNode(LocalNode): - def __init__(self, name: str, validation_config, backend): - super().__init__(name) + def __init__( + self, name: str, validation_config, backend: DataFrameBackend, inputs=None + ): + super().__init__(name, inputs=inputs) self.validation_config = validation_config self.backend = backend @@ -212,9 +248,12 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalOutputNode(LocalNode): def __init__( - self, name: str, feature_view: Union[BatchFeatureView, StreamFeatureView] + self, + name: str, + feature_view: Union[BatchFeatureView, StreamFeatureView], + inputs=None, ): - super().__init__(name) + super().__init__(name, inputs=inputs) self.feature_view = feature_view def execute(self, context: ExecutionContext) -> ArrowTableValue: diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 618a3b780f6..59a271a926e 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -116,6 +116,7 @@ def _materialize_one( try: # ✅ 2. Construct Feature Builder and run it builder = SparkFeatureBuilder( + registry=registry, spark_session=self.spark_session, task=task, ) @@ -211,6 +212,7 @@ def get_historical_features( try: # ✅ 2. 
Construct Feature Builder and run it builder = SparkFeatureBuilder( + registry=registry, spark_session=self.spark_session, task=task, ) diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index a3059105950..e042bb000dc 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -14,69 +14,97 @@ SparkTransformationNode, SparkWriteNode, ) +from feast.infra.registry.base_registry import BaseRegistry class SparkFeatureBuilder(FeatureBuilder): def __init__( self, + registry: BaseRegistry, spark_session: SparkSession, task: Union[MaterializationTask, HistoricalRetrievalTask], ): - super().__init__(task) + super().__init__(registry, task.feature_view, task) self.spark_session = spark_session - def build_source_node(self): - source = self.feature_view.batch_source + def build_source_node(self, view): start_time = self.task.start_time end_time = self.task.end_time - node = SparkReadNode("source", source, self.spark_session, start_time, end_time) + source = view.batch_source + column_info = self.get_column_info(view) + node = SparkReadNode( + f"{view.name}:source", + source, + column_info, + self.spark_session, + start_time, + end_time, + ) self.nodes.append(node) return node - def build_aggregation_node(self, input_node): - agg_specs = self.feature_view.aggregations - group_by_keys = self.feature_view.entities - timestamp_col = self.feature_view.batch_source.timestamp_field - node = SparkAggregationNode("agg", agg_specs, group_by_keys, timestamp_col) - node.add_input(input_node) + def build_aggregation_node(self, view, input_node): + agg_specs = view.aggregations + group_by_keys = view.entities + timestamp_col = view.batch_source.timestamp_field + node = SparkAggregationNode( + f"{view.name}:agg", + agg_specs, + group_by_keys, + timestamp_col, + inputs=[input_node], + ) self.nodes.append(node) return node - def build_join_node(self, input_node): - node = SparkJoinNode("join", self.spark_session) - node.add_input(input_node) + def build_join_node(self, view, input_nodes): + column_info = self.get_column_info(view) + node = SparkJoinNode( + name=f"{view.name}:join", + column_info=column_info, + spark_session=self.spark_session, + inputs=input_nodes, + how="left", # You can make this configurable later + ) self.nodes.append(node) return node - def build_filter_node(self, input_node): - filter_expr = None - if hasattr(self.feature_view, "filter"): - filter_expr = self.feature_view.filter - ttl = self.feature_view.ttl - node = SparkFilterNode("filter", self.spark_session, ttl, filter_expr) - node.add_input(input_node) + def build_filter_node(self, view, input_node): + filter_expr = getattr(view, "filter", None) + ttl = getattr(view, "ttl", None) + column_info = self.get_column_info(view) + node = SparkFilterNode( + f"{view.name}:filter", + column_info, + self.spark_session, + ttl, + filter_expr, + inputs=[input_node], + ) self.nodes.append(node) return node - def build_dedup_node(self, input_node): - node = SparkDedupNode("dedup", self.spark_session) - node.add_input(input_node) + def build_dedup_node(self, view, input_node): + column_info = self.get_column_info(view) + node = SparkDedupNode( + f"{view.name}:dedup", column_info, self.spark_session, inputs=[input_node] + ) self.nodes.append(node) return node - def build_transformation_node(self, input_node): - udf_name = self.feature_view.feature_transformation.name - udf = 
self.feature_view.feature_transformation.udf - node = SparkTransformationNode(udf_name, udf) - node.add_input(input_node) + def build_transformation_node(self, view, input_nodes): + udf_name = view.feature_transformation.name + udf = view.feature_transformation.udf + node = SparkTransformationNode(udf_name, udf, inputs=input_nodes) self.nodes.append(node) return node - def build_output_nodes(self, input_node): - node = SparkWriteNode("output", self.feature_view) - node.add_input(input_node) + def build_output_nodes(self, view, input_node): + node = SparkWriteNode( + f"{view.name}:output", self.dag_root.view, inputs=[input_node] + ) self.nodes.append(node) return node - def build_validation_node(self, input_node): + def build_validation_node(self, view, input_node): pass diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index 1ab454daa52..7c2b0bd7916 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import List, Optional, Union, cast +from typing import Callable, List, Optional, Union, cast from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql import functions as F @@ -8,7 +8,7 @@ from feast.aggregation import Aggregation from feast.data_source import DataSource from feast.infra.common.serde import SerializedArtifacts -from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue @@ -56,20 +56,22 @@ def __init__( self, name: str, source: DataSource, + column_info: ColumnInfo, spark_session: SparkSession, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ): super().__init__(name) self.source = source + self.column_info = column_info self.spark_session = spark_session self.start_time = start_time self.end_time = end_time def execute(self, context: ExecutionContext) -> DAGValue: - column_info = context.column_info retrieval_job = create_offline_store_retrieval_job( data_source=self.source, + column_info=self.column_info, context=context, start_time=self.start_time, end_time=self.end_time, @@ -84,8 +86,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: format=DAGFormat.SPARK, metadata={ "source": "feature_view_batch_source", - "timestamp_field": column_info.timestamp_column, - "created_timestamp_column": column_info.created_timestamp_column, + "timestamp_field": self.column_info.timestamp_column, + "created_timestamp_column": self.column_info.created_timestamp_column, "start_date": self.start_time, "end_date": self.end_time, }, @@ -99,8 +101,9 @@ def __init__( aggregations: List[Aggregation], group_by_keys: List[str], timestamp_col: str, + inputs=None, ): - super().__init__(name) + super().__init__(name, inputs=inputs) self.aggregations = aggregations self.group_by_keys = group_by_keys self.timestamp_col = timestamp_col @@ -148,41 +151,63 @@ class SparkJoinNode(DAGNode): def __init__( self, name: str, + column_info: ColumnInfo, spark_session: SparkSession, + inputs: Optional[List[DAGNode]] = None, + how: str = "inner", ): - super().__init__(name) + super().__init__(name, inputs=inputs or []) + self.column_info = column_info self.spark_session = spark_session + 
self.how = how def execute(self, context: ExecutionContext) -> DAGValue: - feature_value = self.get_single_input_value(context) - feature_value.assert_format(DAGFormat.SPARK) - feature_df: DataFrame = feature_value.data + input_values = self.get_input_values(context) + for val in input_values: + val.assert_format(DAGFormat.SPARK) + + # Join all input DataFrames on join_keys + joined_df = None + for i, dag_value in enumerate(input_values): + df = dag_value.data + + # Use original FeatureView name if available + fv_name = self.inputs[i].name.split(":")[0] + prefix = fv_name + "__" + + # Skip renaming join keys to preserve join compatibility + renamed_cols = [ + F.col(c).alias(f"{prefix}{c}") + if c not in self.column_info.join_keys + else F.col(c) + for c in df.columns + ] + df = df.select(*renamed_cols) + if joined_df is None: + joined_df = df + else: + joined_df = joined_df.join( + df, on=self.column_info.join_keys, how=self.how + ) + # If entity_df is provided, join it in last entity_df = context.entity_df - if entity_df is None: - return DAGValue( - data=feature_df, - format=DAGFormat.SPARK, - metadata={"joined_on": None}, + if entity_df is not None: + entity_df = rename_entity_ts_column( + spark_session=self.spark_session, + entity_df=entity_df, ) + if joined_df is None: + raise RuntimeError("No input features available to join with entity_df") - # Get timestamp fields from feature view - column_info = context.column_info - - # Rename entity_df event_timestamp_col to match feature_df - entity_df = rename_entity_ts_column( - spark_session=self.spark_session, - entity_df=entity_df, - ) - - # Perform left join on entity df - # TODO: give a config option to use other join types - joined = feature_df.join(entity_df, on=column_info.join_keys, how="left") + joined_df = entity_df.join( + joined_df, on=self.column_info.join_keys, how="left" + ) return DAGValue( - data=joined, + data=joined_df, format=DAGFormat.SPARK, - metadata={"joined_on": column_info.join_keys}, + metadata={"joined_on": self.column_info.join_keys, "join_type": self.how}, ) @@ -190,11 +215,14 @@ class SparkFilterNode(DAGNode): def __init__( self, name: str, + column_info: ColumnInfo, spark_session: SparkSession, ttl: Optional[timedelta] = None, filter_condition: Optional[str] = None, + inputs=None, ): - super().__init__(name) + super().__init__(name, inputs=inputs) + self.column_info = column_info self.spark_session = spark_session self.ttl = ttl self.filter_condition = filter_condition @@ -205,7 +233,7 @@ def execute(self, context: ExecutionContext) -> DAGValue: input_df: DataFrame = input_value.data # Get timestamp fields from feature view - timestamp_column = context.column_info.timestamp_column + timestamp_column = self.column_info.timestamp_column # Optional filter: feature.ts <= entity.event_timestamp filtered_df = input_df @@ -237,9 +265,12 @@ class SparkDedupNode(DAGNode): def __init__( self, name: str, + column_info: ColumnInfo, spark_session: SparkSession, + inputs=None, ): - super().__init__(name) + super().__init__(name, inputs=inputs) + self.column_info = column_info self.spark_session = spark_session def execute(self, context: ExecutionContext) -> DAGValue: @@ -247,17 +278,14 @@ def execute(self, context: ExecutionContext) -> DAGValue: input_value.assert_format(DAGFormat.SPARK) input_df: DataFrame = input_value.data - # Get timestamp fields from feature view - colmun_info = context.column_info - # Dedup based on join keys and event timestamp column # Dedup with row_number - partition_cols = 
context.column_info.join_keys + partition_cols = self.column_info.join_keys deduped_df = input_df if partition_cols: - ordering = [F.col(colmun_info.timestamp_column).desc()] - if colmun_info.created_timestamp_column: - ordering.append(F.col(colmun_info.created_timestamp_column).desc()) + ordering = [F.col(self.column_info.timestamp_column).desc()] + if self.column_info.created_timestamp_column: + ordering.append(F.col(self.column_info.created_timestamp_column).desc()) window = Window.partitionBy(*partition_cols).orderBy(*ordering) deduped_df = ( @@ -278,8 +306,9 @@ def __init__( self, name: str, feature_view: Union[BatchFeatureView, StreamFeatureView], + inputs=None, ): - super().__init__(name) + super().__init__(name, inputs=inputs) self.feature_view = feature_view def execute(self, context: ExecutionContext) -> DAGValue: @@ -324,15 +353,18 @@ def execute(self, context: ExecutionContext) -> DAGValue: class SparkTransformationNode(DAGNode): - def __init__(self, name: str, udf): - super().__init__(name) + def __init__(self, name: str, udf: Callable, inputs: List[DAGNode]): + super().__init__(name, inputs) self.udf = udf def execute(self, context: ExecutionContext) -> DAGValue: - input_val = self.get_single_input_value(context) - input_val.assert_format(DAGFormat.SPARK) + input_values = self.get_input_values(context) + for val in input_values: + val.assert_format(DAGFormat.SPARK) + + input_dfs: List[DataFrame] = [val.data for val in input_values] - transformed_df = self.udf(input_val.data) + transformed_df = self.udf(*input_dfs) return DAGValue( data=transformed_df, format=DAGFormat.SPARK, metadata={"transformed": True} diff --git a/sdk/python/feast/infra/compute_engines/utils.py b/sdk/python/feast/infra/compute_engines/utils.py index 09a13a72193..20a3dae981d 100644 --- a/sdk/python/feast/infra/compute_engines/utils.py +++ b/sdk/python/feast/infra/compute_engines/utils.py @@ -2,12 +2,13 @@ from typing import Optional from feast.data_source import DataSource -from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.offline_stores.offline_store import RetrievalJob def create_offline_store_retrieval_job( data_source: DataSource, + column_info: ColumnInfo, context: ExecutionContext, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -16,6 +17,7 @@ def create_offline_store_retrieval_job( Create a retrieval job for the offline store. Args: data_source: The data source to pull from. + column_info: Column information containing join keys, feature columns, and timestamps. 
context: start_time: end_time: @@ -24,7 +26,6 @@ def create_offline_store_retrieval_job( """ offline_store = context.offline_store - column_info = context.column_info # ๐Ÿ“ฅ Reuse Feast's robust query resolver retrieval_job = offline_store.pull_all_from_table_or_query( config=context.repo_config, diff --git a/sdk/python/feast/protos/feast/core/FeatureView_pb2.py b/sdk/python/feast/protos/feast/core/FeatureView_pb2.py index d1456cf9faf..702335f1166 100644 --- a/sdk/python/feast/protos/feast/core/FeatureView_pb2.py +++ b/sdk/python/feast/protos/feast/core/FeatureView_pb2.py @@ -18,7 +18,7 @@ from feast.protos.feast.core import Feature_pb2 as feast_dot_core_dot_Feature__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66\x65\x61st/core/FeatureView.proto\x12\nfeast.core\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1b\x66\x65\x61st/core/DataSource.proto\x1a\x18\x66\x65\x61st/core/Feature.proto\"c\n\x0b\x46\x65\x61tureView\x12)\n\x04spec\x18\x01 \x01(\x0b\x32\x1b.feast.core.FeatureViewSpec\x12)\n\x04meta\x18\x02 \x01(\x0b\x32\x1b.feast.core.FeatureViewMeta\"\xce\x03\n\x0f\x46\x65\x61tureViewSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07project\x18\x02 \x01(\t\x12\x10\n\x08\x65ntities\x18\x03 \x03(\t\x12+\n\x08\x66\x65\x61tures\x18\x04 \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x31\n\x0e\x65ntity_columns\x18\x0c \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x13\n\x0b\x64\x65scription\x18\n \x01(\t\x12\x33\n\x04tags\x18\x05 \x03(\x0b\x32%.feast.core.FeatureViewSpec.TagsEntry\x12\r\n\x05owner\x18\x0b \x01(\t\x12&\n\x03ttl\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\x12,\n\x0c\x62\x61tch_source\x18\x07 \x01(\x0b\x32\x16.feast.core.DataSource\x12-\n\rstream_source\x18\t \x01(\x0b\x32\x16.feast.core.DataSource\x12\x0e\n\x06online\x18\x08 \x01(\x08\x12\x0f\n\x07offline\x18\r \x01(\x08\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xcc\x01\n\x0f\x46\x65\x61tureViewMeta\x12\x35\n\x11\x63reated_timestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x16last_updated_timestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x46\n\x19materialization_intervals\x18\x03 \x03(\x0b\x32#.feast.core.MaterializationInterval\"w\n\x17MaterializationInterval\x12.\n\nstart_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"@\n\x0f\x46\x65\x61tureViewList\x12-\n\x0c\x66\x65\x61tureviews\x18\x01 \x03(\x0b\x32\x17.feast.core.FeatureViewBU\n\x10\x66\x65\x61st.proto.coreB\x10\x46\x65\x61tureViewProtoZ/github.com/feast-dev/feast/go/protos/feast/coreb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66\x65\x61st/core/FeatureView.proto\x12\nfeast.core\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1b\x66\x65\x61st/core/DataSource.proto\x1a\x18\x66\x65\x61st/core/Feature.proto\"c\n\x0b\x46\x65\x61tureView\x12)\n\x04spec\x18\x01 \x01(\x0b\x32\x1b.feast.core.FeatureViewSpec\x12)\n\x04meta\x18\x02 \x01(\x0b\x32\x1b.feast.core.FeatureViewMeta\"\x81\x04\n\x0f\x46\x65\x61tureViewSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07project\x18\x02 \x01(\t\x12\x10\n\x08\x65ntities\x18\x03 \x03(\t\x12+\n\x08\x66\x65\x61tures\x18\x04 \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x31\n\x0e\x65ntity_columns\x18\x0c \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x13\n\x0b\x64\x65scription\x18\n \x01(\t\x12\x33\n\x04tags\x18\x05 
\x03(\x0b\x32%.feast.core.FeatureViewSpec.TagsEntry\x12\r\n\x05owner\x18\x0b \x01(\t\x12&\n\x03ttl\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\x12,\n\x0c\x62\x61tch_source\x18\x07 \x01(\x0b\x32\x16.feast.core.DataSource\x12-\n\rstream_source\x18\t \x01(\x0b\x32\x16.feast.core.DataSource\x12\x0e\n\x06online\x18\x08 \x01(\x08\x12\x0f\n\x07offline\x18\r \x01(\x08\x12\x31\n\x0csource_views\x18\x0e \x03(\x0b\x32\x1b.feast.core.FeatureViewSpec\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xcc\x01\n\x0f\x46\x65\x61tureViewMeta\x12\x35\n\x11\x63reated_timestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x16last_updated_timestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x46\n\x19materialization_intervals\x18\x03 \x03(\x0b\x32#.feast.core.MaterializationInterval\"w\n\x17MaterializationInterval\x12.\n\nstart_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"@\n\x0f\x46\x65\x61tureViewList\x12-\n\x0c\x66\x65\x61tureviews\x18\x01 \x03(\x0b\x32\x17.feast.core.FeatureViewBU\n\x10\x66\x65\x61st.proto.coreB\x10\x46\x65\x61tureViewProtoZ/github.com/feast-dev/feast/go/protos/feast/coreb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,13 +31,13 @@ _globals['_FEATUREVIEW']._serialized_start=164 _globals['_FEATUREVIEW']._serialized_end=263 _globals['_FEATUREVIEWSPEC']._serialized_start=266 - _globals['_FEATUREVIEWSPEC']._serialized_end=728 - _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_start=685 - _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_end=728 - _globals['_FEATUREVIEWMETA']._serialized_start=731 - _globals['_FEATUREVIEWMETA']._serialized_end=935 - _globals['_MATERIALIZATIONINTERVAL']._serialized_start=937 - _globals['_MATERIALIZATIONINTERVAL']._serialized_end=1056 - _globals['_FEATUREVIEWLIST']._serialized_start=1058 - _globals['_FEATUREVIEWLIST']._serialized_end=1122 + _globals['_FEATUREVIEWSPEC']._serialized_end=779 + _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_start=736 + _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_end=779 + _globals['_FEATUREVIEWMETA']._serialized_start=782 + _globals['_FEATUREVIEWMETA']._serialized_end=986 + _globals['_MATERIALIZATIONINTERVAL']._serialized_start=988 + _globals['_MATERIALIZATIONINTERVAL']._serialized_end=1107 + _globals['_FEATUREVIEWLIST']._serialized_start=1109 + _globals['_FEATUREVIEWLIST']._serialized_end=1173 # @@protoc_insertion_point(module_scope) diff --git a/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi b/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi index 6abeb85e263..d93c9b8f80f 100644 --- a/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi +++ b/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi @@ -91,6 +91,7 @@ class FeatureViewSpec(google.protobuf.message.Message): STREAM_SOURCE_FIELD_NUMBER: builtins.int ONLINE_FIELD_NUMBER: builtins.int OFFLINE_FIELD_NUMBER: builtins.int + SOURCE_VIEWS_FIELD_NUMBER: builtins.int name: builtins.str """Name of the feature view. Must be unique. Not updated.""" project: builtins.str @@ -130,6 +131,8 @@ class FeatureViewSpec(google.protobuf.message.Message): """ offline: builtins.bool """Whether these features should be written to the offline store""" + @property + def source_views(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___FeatureViewSpec]: ... 
def __init__( self, *, @@ -146,9 +149,10 @@ class FeatureViewSpec(google.protobuf.message.Message): stream_source: feast.core.DataSource_pb2.DataSource | None = ..., online: builtins.bool = ..., offline: builtins.bool = ..., + source_views: collections.abc.Iterable[global___FeatureViewSpec] | None = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "stream_source", b"stream_source", "ttl", b"ttl"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "description", b"description", "entities", b"entities", "entity_columns", b"entity_columns", "features", b"features", "name", b"name", "offline", b"offline", "online", b"online", "owner", b"owner", "project", b"project", "stream_source", b"stream_source", "tags", b"tags", "ttl", b"ttl"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "description", b"description", "entities", b"entities", "entity_columns", b"entity_columns", "features", b"features", "name", b"name", "offline", b"offline", "online", b"online", "owner", b"owner", "project", b"project", "source_views", b"source_views", "stream_source", b"stream_source", "tags", b"tags", "ttl", b"ttl"]) -> None: ... global___FeatureViewSpec = FeatureViewSpec diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index e3608b10354..dcbbd33df7c 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -69,6 +69,7 @@ class StreamFeatureView(FeatureView): entities: List[str] ttl: Optional[timedelta] source: DataSource + sink_source: Optional[DataSource] = None schema: List[Field] entity_columns: List[Field] features: List[Field] @@ -90,7 +91,8 @@ def __init__( self, *, name: str, - source: DataSource, + source: Union[DataSource, "StreamFeatureView", List["StreamFeatureView"]], + sink_source: Optional[DataSource] = None, entities: Optional[List[Entity]] = None, ttl: timedelta = timedelta(days=0), tags: Optional[Dict[str, str]] = None, @@ -114,7 +116,7 @@ def __init__( RuntimeWarning, ) - if ( + if isinstance(source, DataSource) and ( type(source).__name__ not in SUPPORTED_STREAM_SOURCES and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE ): @@ -148,7 +150,8 @@ def __init__( description=description, owner=owner, schema=schema, - source=source, + source=source, # type: ignore[arg-type] + sink_source=sink_source, ) def get_feature_transformation(self) -> Optional[Transformation]: diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 3d44a130d64..e0855ae31f3 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -1,15 +1,13 @@ -from datetime import datetime, timedelta +from datetime import timedelta from typing import cast from unittest.mock import MagicMock -import pandas as pd import pytest from pyspark.sql import DataFrame from tqdm import tqdm -from feast import BatchFeatureView, Entity, Field +from feast import BatchFeatureView, Field from feast.aggregation import Aggregation -from feast.data_source import DataSource from feast.infra.common.materialization_job import ( MaterializationJobStatus, MaterializationTask, @@ -20,101 +18,18 @@ from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( 
SparkOfflineStore, ) -from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import ( - SparkDataSourceCreator, -) from feast.types import Float32, Int32, Int64 -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, -) -from tests.integration.feature_repos.repo_configuration import ( - construct_test_environment, -) -from tests.integration.feature_repos.universal.online_store.redis import ( - RedisOnlineStoreCreator, -) - -now = datetime.now() -today = datetime.today() - -driver = Entity( - name="driver_id", - description="driver id", +from tests.integration.compute_engines.spark.utils import ( + _check_offline_features, + _check_online_features, + create_entity_df, + create_feature_dataset, + create_spark_environment, + driver, + now, ) -def create_feature_dataset(spark_environment) -> DataSource: - yesterday = today - timedelta(days=1) - last_week = today - timedelta(days=7) - df = pd.DataFrame( - [ - { - "driver_id": 1001, - "event_timestamp": yesterday, - "created": now - timedelta(hours=2), - "conv_rate": 0.8, - "acc_rate": 0.5, - "avg_daily_trips": 15, - }, - { - "driver_id": 1001, - "event_timestamp": last_week, - "created": now - timedelta(hours=3), - "conv_rate": 0.75, - "acc_rate": 0.9, - "avg_daily_trips": 14, - }, - { - "driver_id": 1002, - "event_timestamp": yesterday, - "created": now - timedelta(hours=2), - "conv_rate": 0.7, - "acc_rate": 0.4, - "avg_daily_trips": 12, - }, - { - "driver_id": 1002, - "event_timestamp": yesterday - timedelta(days=1), - "created": now - timedelta(hours=2), - "conv_rate": 0.3, - "acc_rate": 0.6, - "avg_daily_trips": 12, - }, - ] - ) - ds = spark_environment.data_source_creator.create_data_source( - df, - spark_environment.feature_store.project, - timestamp_field="event_timestamp", - created_timestamp_column="created", - ) - return ds - - -def create_entity_df() -> pd.DataFrame: - entity_df = pd.DataFrame( - [ - {"driver_id": 1001, "event_timestamp": today}, - {"driver_id": 1002, "event_timestamp": today}, - ] - ) - return entity_df - - -def create_spark_environment(): - spark_config = IntegrationTestRepoConfig( - provider="local", - online_store_creator=RedisOnlineStoreCreator, - offline_store_creator=SparkDataSourceCreator, - batch_engine={"type": "spark.engine", "partitions": 10}, - ) - spark_environment = construct_test_environment( - spark_config, None, entity_key_serialization_version=3 - ) - spark_environment.setup() - return spark_environment - - @pytest.mark.integration def test_spark_compute_engine_get_historical_features(): spark_environment = create_spark_environment() @@ -123,8 +38,8 @@ def test_spark_compute_engine_get_historical_features(): data_source = create_feature_dataset(spark_environment) def transform_feature(df: DataFrame) -> DataFrame: - df = df.withColumn("sum_conv_rate", df["sum_conv_rate"] * 2) - df = df.withColumn("avg_acc_rate", df["avg_acc_rate"] * 2) + df = df.withColumn("conv_rate", df["conv_rate"] * 2) + df = df.withColumn("acc_rate", df["acc_rate"] * 2) return df driver_stats_fv = BatchFeatureView( @@ -176,9 +91,9 @@ def transform_feature(df: DataFrame) -> DataFrame: # โœ… Assert output assert df_out.driver_id.to_list() == [1001, 1002] - assert abs(df_out["sum_conv_rate"].to_list()[0] - 1.6) < 1e-6 + assert abs(df_out["sum_conv_rate"].to_list()[0] - 3.1) < 1e-6 assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 - assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.4) < 
1e-6 assert abs(df_out["avg_acc_rate"].to_list()[1] - 1.0) < 1e-6 finally: @@ -277,39 +192,5 @@ def tqdm_builder(length): spark_environment.teardown() -def _check_online_features( - fs, - driver_id, - feature, - expected_value, - full_feature_names: bool = True, -): - online_response = fs.get_online_features( - features=[feature], - entity_rows=[{"driver_id": driver_id}], - full_feature_names=full_feature_names, - ).to_dict() - - feature_ref = "__".join(feature.split(":")) - - assert len(online_response["driver_id"]) == 1 - assert online_response["driver_id"][0] == driver_id - assert abs(online_response[feature_ref][0] - expected_value < 1e-6), ( - "Transformed result" - ) - - -def _check_offline_features( - fs, - feature, - entity_df, -): - offline_df = fs.get_historical_features( - entity_df=entity_df, - features=[feature], - ).to_df() - assert len(offline_df) == 4 - - if __name__ == "__main__": test_spark_compute_engine_get_historical_features() diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute_dag.py b/sdk/python/tests/integration/compute_engines/spark/test_compute_dag.py new file mode 100644 index 00000000000..24277d9a323 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute_dag.py @@ -0,0 +1,224 @@ +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest +from pyspark.sql import DataFrame +from tqdm import tqdm + +from feast import BatchFeatureView, Field +from feast.aggregation import Aggregation +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.compute_engines.spark.compute import SparkComputeEngine +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkOfflineStore, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) +from feast.types import Float32, Int32, Int64 +from tests.integration.compute_engines.spark.utils import ( + _check_offline_features, + _check_online_features, + create_entity_df, + create_feature_dataset, + create_spark_environment, + driver, + now, +) + + +def create_base_feature_view(source): + return BatchFeatureView( + name="hourly_driver_stats", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=source, + ) + + +def create_agg_feature_view(source): + return BatchFeatureView( + name="agg_hourly_driver_stats", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=source, + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + ) + + +def create_chained_feature_view(base_fv: BatchFeatureView): + def transform_feature(df: DataFrame) -> DataFrame: + df = df.withColumn("conv_rate", df["conv_rate"] * 2) + df = df.withColumn("acc_rate", df["acc_rate"] * 2) + return df + + return BatchFeatureView( + name="daily_driver_stats", + entities=[driver], + udf=transform_feature, + udf_string="transform", + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=base_fv, + sink_source=SparkSource( + 
name="daily_driver_stats_sink", + path="/tmp/daily_driver_stats_sink", + file_format="parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created", + ), + ) + + +@pytest.mark.integration +def test_spark_dag_materialize_recursive_view(): + spark_env = create_spark_environment() + fs = spark_env.feature_store + registry = fs.registry + source = create_feature_dataset(spark_env) + + base_fv = create_base_feature_view(source) + chained_fv = create_chained_feature_view(base_fv) + + def tqdm_builder(length): + return tqdm(total=length, ncols=100) + + try: + fs.apply([driver, base_fv, chained_fv]) + + # ๐Ÿงช Materialize top-level view; DAG will include base_fv implicitly + task = MaterializationTask( + project=fs.project, + feature_view=chained_fv, + start_time=now - timedelta(days=2), + end_time=now, + tqdm_builder=tqdm_builder, + ) + + engine = SparkComputeEngine( + repo_config=spark_env.config, + offline_store=SparkOfflineStore(), + online_store=MagicMock(), + registry=registry, + ) + + jobs = engine.materialize(registry, task) + + # โœ… Validate jobs ran + assert len(jobs) == 1 + assert jobs[0].status() == MaterializationJobStatus.SUCCEEDED + + _check_online_features( + fs=fs, + driver_id=1001, + feature="daily_driver_stats:conv_rate", + expected_value=1.6, + full_feature_names=True, + ) + + entity_df = create_entity_df() + + _check_offline_features( + fs=fs, feature="hourly_driver_stats:conv_rate", entity_df=entity_df, size=2 + ) + finally: + spark_env.teardown() + + +@pytest.mark.integration +def test_spark_dag_materialize_multi_views(): + spark_env = create_spark_environment() + fs = spark_env.feature_store + registry = fs.registry + source = create_feature_dataset(spark_env) + + base_fv = create_base_feature_view(source) + chained_fv = create_chained_feature_view(base_fv) + + multi_view = BatchFeatureView( + name="multi_view", + entities=[driver], + schema=[ + Field(name="driver_id", dtype=Int32), + Field(name="daily_driver_stats__conv_rate", dtype=Float32), + Field(name="daily_driver_stats__acc_rate", dtype=Float32), + ], + online=True, + offline=True, + source=[base_fv, chained_fv], + sink_source=SparkSource( + name="multi_view_sink", + path="/tmp/multi_view_sink", + file_format="parquet", + timestamp_field="daily_driver_stats__event_timestamp", + created_timestamp_column="daily_driver_stats__created", + ), + ) + + def tqdm_builder(length): + return tqdm(total=length, ncols=100) + + try: + fs.apply([driver, base_fv, chained_fv, multi_view]) + + # ๐Ÿงช Materialize multi-view + task = MaterializationTask( + project=fs.project, + feature_view=multi_view, + start_time=now - timedelta(days=2), + end_time=now, + tqdm_builder=tqdm_builder, + ) + + engine = SparkComputeEngine( + repo_config=spark_env.config, + offline_store=SparkOfflineStore(), + online_store=MagicMock(), + registry=registry, + ) + + jobs = engine.materialize(registry, task) + + # โœ… Validate jobs ran + assert len(jobs) == 1 + assert jobs[0].status() == MaterializationJobStatus.SUCCEEDED + + _check_online_features( + fs=fs, + driver_id=1001, + feature="multi_view:daily_driver_stats__conv_rate", + expected_value=1.6, + full_feature_names=True, + ) + + entity_df = create_entity_df() + + _check_offline_features( + fs=fs, feature="hourly_driver_stats:conv_rate", entity_df=entity_df, size=2 + ) + finally: + spark_env.teardown() diff --git a/sdk/python/tests/integration/compute_engines/spark/utils.py b/sdk/python/tests/integration/compute_engines/spark/utils.py new file mode 100644 index 
00000000000..20ffba4eff1 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/spark/utils.py @@ -0,0 +1,133 @@ +from datetime import datetime, timedelta + +import pandas as pd + +from feast import Entity +from feast.data_source import DataSource +from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import ( + SparkDataSourceCreator, +) +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.repo_configuration import ( + construct_test_environment, +) +from tests.integration.feature_repos.universal.online_store.redis import ( + RedisOnlineStoreCreator, +) + +now = datetime.now() +today = datetime.today() + +driver = Entity( + name="driver_id", + description="driver id", +) + + +def create_entity_df() -> pd.DataFrame: + entity_df = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": today}, + {"driver_id": 1002, "event_timestamp": today}, + ] + ) + return entity_df + + +def create_feature_dataset(spark_environment) -> DataSource: + yesterday = today - timedelta(days=1) + last_week = today - timedelta(days=7) + df = pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.5, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": last_week, + "created": now - timedelta(hours=3), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.7, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.3, + "acc_rate": 0.6, + "avg_daily_trips": 12, + }, + ] + ) + ds = spark_environment.data_source_creator.create_data_source( + df, + spark_environment.feature_store.project, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + return ds + + +def create_spark_environment(): + spark_config = IntegrationTestRepoConfig( + provider="local", + online_store_creator=RedisOnlineStoreCreator, + offline_store_creator=SparkDataSourceCreator, + batch_engine={"type": "spark.engine", "partitions": 10}, + ) + spark_environment = construct_test_environment( + spark_config, None, entity_key_serialization_version=3 + ) + spark_environment.setup() + return spark_environment + + +def _check_online_features( + fs, + driver_id, + feature, + expected_value, + full_feature_names: bool = True, +): + online_response = fs.get_online_features( + features=[feature], + entity_rows=[{"driver_id": driver_id}], + full_feature_names=full_feature_names, + ).to_dict() + + feature_ref = "__".join(feature.split(":")) + + assert len(online_response["driver_id"]) == 1 + assert online_response["driver_id"][0] == driver_id + assert abs(online_response[feature_ref][0] - expected_value) < 1e-6, ( + "Transformed result" + ) + + +def _check_offline_features( + fs, + feature, + entity_df, + size: int = 4, +): + offline_df = fs.get_historical_features( + entity_df=entity_df, + features=[feature], + ).to_df() + assert len(offline_df) == size diff --git a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py index c486b4148fc..20e23c35e03 100644 --- a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py +++
b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py @@ -45,12 +45,6 @@ def create_context(node_outputs): entity_defs=MagicMock(), entity_df=entity_df, node_outputs=node_outputs, - column_info=ColumnInfo( - join_keys=["entity_id"], - feature_cols=["value"], - ts_col="event_timestamp", - created_ts_col=None, - ), ) @@ -64,6 +58,12 @@ def test_local_filter_node(): name="filter", backend=backend, filter_expr="value > 15", + column_info=ColumnInfo( + join_keys=["entity_id"], + feature_cols=["value"], + ts_col="event_timestamp", + created_ts_col=None, + ), ) filter_node.add_input(MagicMock()) filter_node.inputs[0].name = "source" @@ -110,6 +110,12 @@ def test_local_join_node(): join_node = LocalJoinNode( name="join", backend=backend, + column_info=ColumnInfo( + join_keys=["entity_id"], + feature_cols=["value"], + ts_col="event_timestamp", + created_ts_col=None, + ), ) join_node.add_input(MagicMock()) join_node.inputs[0].name = "source" @@ -156,7 +162,16 @@ def test_local_dedup_node(): context.entity_timestamp_col = "event_timestamp" # Build node - node = LocalDedupNode(name="dedup", backend=backend) + node = LocalDedupNode( + name="dedup", + backend=backend, + column_info=ColumnInfo( + join_keys=["entity_id"], + feature_cols=["value"], + ts_col="event_timestamp", + created_ts_col="created_ts", + ), + ) node.add_input(MagicMock()) node.inputs[0].name = "source" diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py index 3f681017e89..61824074ae1 100644 --- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -58,19 +58,17 @@ def strip_extra_spaces(df): online_store=MagicMock(), entity_defs=MagicMock(), entity_df=None, - column_info=ColumnInfo( - join_keys=["name"], - feature_cols=["age"], - ts_col="", - created_ts_col="", - ), node_outputs={"source": input_value}, ) + # Prepare mock input node + input_node = MagicMock() + input_node.name = "source" + # Create and run the node - node = SparkTransformationNode("transform", udf=strip_extra_spaces) - node.add_input(MagicMock()) - node.inputs[0].name = "source" + node = SparkTransformationNode( + "transform", udf=strip_extra_spaces, inputs=[input_node] + ) result = node.execute(context) # Assert output @@ -104,12 +102,6 @@ def test_spark_aggregation_node_executes_correctly(spark_session): online_store=MagicMock(), entity_defs=[], entity_df=None, - column_info=ColumnInfo( - join_keys=["user_id"], - feature_cols=["value"], - ts_col="", - created_ts_col="", - ), node_outputs={"source": input_value}, ) @@ -188,23 +180,26 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): entity_defs=[driver], entity_df=entity_df, node_outputs={ - "feature_node": feature_val, + "source": feature_val, }, - column_info=ColumnInfo( - join_keys=["driver_id"], - feature_cols=["conv_rate", "acc_rate", "avg_daily_trips"], - ts_col="event_timestamp", - created_ts_col="created", - ), ) + # Prepare mock input node + input_node = MagicMock() + input_node.name = "source" + # Create the node and add input join_node = SparkJoinNode( name="join", spark_session=spark_session, + inputs=[input_node], + column_info=ColumnInfo( + join_keys=["driver_id"], + feature_cols=["conv_rate", "acc_rate", "avg_daily_trips"], + ts_col="event_timestamp", + created_ts_col="created", + ), ) - join_node.add_input(MagicMock()) - join_node.inputs[0].name = "feature_node" # Execute the node output = 
join_node.execute(context) @@ -213,6 +208,16 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): dedup_node = SparkDedupNode( name="dedup", spark_session=spark_session, + column_info=ColumnInfo( + join_keys=["driver_id"], + feature_cols=[ + "source__conv_rate", + "source__acc_rate", + "source__avg_daily_trips", + ], + ts_col="source__event_timestamp", + created_ts_col="source__created", + ), ) dedup_node.add_input(MagicMock()) dedup_node.inputs[0].name = "join" @@ -225,10 +230,10 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): # Validate result for driver_id = 1001 assert result_df[0]["driver_id"] == 1001 - assert abs(result_df[0]["conv_rate"] - 0.8) < 1e-6 - assert result_df[0]["avg_daily_trips"] == 15 + assert abs(result_df[0]["source__conv_rate"] - 0.8) < 1e-6 + assert result_df[0]["source__avg_daily_trips"] == 15 # Validate result for driver_id = 1002 assert result_df[1]["driver_id"] == 1002 - assert abs(result_df[1]["conv_rate"] - 0.7) < 1e-6 - assert result_df[1]["avg_daily_trips"] == 12 + assert abs(result_df[1]["source__conv_rate"] - 0.7) < 1e-6 + assert result_df[1]["source__avg_daily_trips"] == 12 diff --git a/sdk/python/tests/unit/infra/compute_engines/test_feature_builder.py b/sdk/python/tests/unit/infra/compute_engines/test_feature_builder.py new file mode 100644 index 00000000000..b78ef15299c --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/test_feature_builder.py @@ -0,0 +1,145 @@ +from unittest.mock import MagicMock + +from feast.data_source import DataSource +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.feature_builder import FeatureBuilder + +# --------------------------- +# Minimal Mock DAGNode for testing +# --------------------------- + + +class MockDAGNode(DAGNode): + def __init__(self, name, inputs=None): + super().__init__(name, inputs=inputs or []) + + def execute(self, context: ExecutionContext) -> DAGValue: + return DAGValue(data=None, format=DAGFormat.SPARK, metadata={}) + + +# --------------------------- +# Mock Feature View Definitions +# --------------------------- + + +class MockFeatureView: + def __init__( + self, + name, + source=None, + aggregations=None, + feature_transformation=None, + ): + self.name = name + self.source = source + + # Internal resolution (emulating what real FeatureView.__init__ would do) + self.data_source = source if isinstance(source, DataSource) else None + self.source_views = ( + [source] + if isinstance(source, MockFeatureView) + else source + if isinstance(source, list) + else [] + ) + + self.aggregations = aggregations or [] + self.feature_transformation = feature_transformation + self.ttl = None + self.filter = None + self.enable_validation = False + self.entities = ["driver_id"] + self.batch_source = type("BatchSource", (), {"timestamp_field": "ts"}) + self.stream_source = None + self.tags = {} + + +class MockTransformation: + def __init__(self, name): + self.name = name + self.udf = lambda df: df + + +mock_source = MagicMock(spec=DataSource) + +# --------------------------- +# Mock DAG +# --------------------------- + +hourly_driver_stats = MockFeatureView( + name="hourly_driver_stats", + source=mock_source, + aggregations=[{"function": "sum", "column": "trips"}], + 
feature_transformation=MockTransformation("hourly_tf"), +) + +daily_driver_stats = MockFeatureView( + name="daily_driver_stats", + source=hourly_driver_stats, + aggregations=[{"function": "mean", "column": "trips"}], + feature_transformation=MockTransformation("daily_tf"), +) + + +# --------------------------- +# Mock FeatureBuilder +# --------------------------- + + +class MockFeatureBuilder(FeatureBuilder): + def __init__(self, feature_view): + super().__init__( + registry=MagicMock(), feature_view=feature_view, task=MagicMock() + ) + + def build_source_node(self, view): + return MockDAGNode(f"Source({view.name})") + + def build_join_node(self, view, input_nodes): + return MockDAGNode(f"Join({view.name})", inputs=input_nodes) + + def build_filter_node(self, view, input_node): + return MockDAGNode(f"Filter({view.name})", inputs=[input_node]) + + def build_aggregation_node(self, view, input_node): + return MockDAGNode(f"Agg({view.name})", inputs=[input_node]) + + def build_dedup_node(self, view, input_node): + return MockDAGNode(f"Dedup({view.name})", inputs=[input_node]) + + def build_transformation_node(self, view, input_nodes): + return MockDAGNode(f"Transform({view.name})", inputs=input_nodes) + + def build_validation_node(self, view, input_node): + return MockDAGNode(f"Validate({view.name})", inputs=[input_node]) + + def build_output_nodes(self, view, final_node): + output_node = MockDAGNode(f"Output({final_node.name})", inputs=[final_node]) + self.nodes.append(output_node) + return output_node + + +# --------------------------- +# Test +# --------------------------- + + +def test_recursive_featureview_build(): + builder = MockFeatureBuilder(daily_driver_stats) + execution_plan: ExecutionPlan = builder.build() + + expected_output = """\ +- Output(Agg(daily_driver_stats)) + - Agg(daily_driver_stats) + - Filter(daily_driver_stats) + - Transform(daily_driver_stats) + - Agg(hourly_driver_stats) + - Filter(hourly_driver_stats) + - Transform(hourly_driver_stats) + - Source(hourly_driver_stats)""" + + assert execution_plan.to_dag() == expected_output diff --git a/sdk/python/tests/unit/permissions/conftest.py b/sdk/python/tests/unit/permissions/conftest.py index ba277d13b49..fceb9f0b197 100644 --- a/sdk/python/tests/unit/permissions/conftest.py +++ b/sdk/python/tests/unit/permissions/conftest.py @@ -1,8 +1,9 @@ -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock import pytest from feast import FeatureView +from feast.data_source import DataSource from feast.entity import Entity from feast.infra.registry.base_registry import BaseRegistry from feast.permissions.decorator import require_permissions @@ -17,9 +18,14 @@ class SecuredFeatureView(FeatureView): def __init__(self, name, tags): + mock_source = MagicMock(spec=DataSource) + mock_source.created_timestamp_column = None + mock_source.timestamp_field = None + mock_source.date_partition_column = None + super().__init__( name=name, - source=Mock(), + source=mock_source, tags=tags, )
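
Reviewer note (not part of the patch): the sketch below is a minimal, illustrative example of the derived-view API that the new Spark DAG tests above exercise. It assumes a Spark offline store; the view names, schema fields, and /tmp paths are placeholders rather than values introduced by this change.

from pyspark.sql import DataFrame

from feast import BatchFeatureView, Entity, Field
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,
)
from feast.types import Float32, Int32

driver = Entity(name="driver_id", description="driver id")

# Base view: reads directly from a batch DataSource, so no sink_source is needed.
hourly_driver_stats = BatchFeatureView(
    name="hourly_driver_stats",
    entities=[driver],
    schema=[
        Field(name="conv_rate", dtype=Float32),
        Field(name="driver_id", dtype=Int32),
    ],
    online=True,
    offline=True,
    source=SparkSource(
        name="driver_stats_source",
        path="/tmp/driver_stats",  # placeholder path
        file_format="parquet",
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
    ),
)


def double_conv_rate(df: DataFrame) -> DataFrame:
    # Simple Spark transformation applied to the upstream view's output.
    return df.withColumn("conv_rate", df["conv_rate"] * 2)


# Derived view: `source` is another FeatureView (or a list of them), so an
# explicit `sink_source` tells the compute engine where to persist the output.
daily_driver_stats = BatchFeatureView(
    name="daily_driver_stats",
    entities=[driver],
    udf=double_conv_rate,
    udf_string="double_conv_rate",
    schema=[
        Field(name="conv_rate", dtype=Float32),
        Field(name="driver_id", dtype=Int32),
    ],
    online=True,
    offline=True,
    source=hourly_driver_stats,
    sink_source=SparkSource(
        name="daily_driver_stats_sink",
        path="/tmp/daily_driver_stats_sink",  # placeholder path
        file_format="parquet",
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
    ),
)

# Applying both views and materializing daily_driver_stats builds a DAG that
# reads hourly_driver_stats first, as exercised by
# test_spark_dag_materialize_recursive_view above.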