diff --git a/Makefile b/Makefile index b8a34855fbf..20220164e87 100644 --- a/Makefile +++ b/Makefile @@ -302,6 +302,29 @@ test-python-universal-postgres-offline: ## Run Python Postgres integration tests not test_spark" \ sdk/python/tests +test-python-universal-ray-offline: ## Run Python Ray offline store integration tests + PYTHONPATH='.' \ + FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.ray_repo_configuration \ + PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.ray_offline_store.tests \ + python -m pytest -n 8 --integration \ + -m "not universal_online_stores and not benchmark" \ + -k "not test_historical_retrieval_with_validation and \ + not test_universal_cli and \ + not test_go_feature_server and \ + not test_feature_logging and \ + not test_logged_features_validation and \ + not test_lambda_materialization_consistency and \ + not gcs_registry and \ + not s3_registry and \ + not test_snowflake and \ + not test_spark" \ + sdk/python/tests + +test-python-ray-compute-engine: ## Run Python Ray compute engine tests + PYTHONPATH='.' \ + python -m pytest -v --integration \ + sdk/python/tests/integration/compute_engines/ray_compute/ + test-python-universal-postgres-online: ## Run Python Postgres integration tests PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.postgres_online_store.postgres_repo_configuration \ diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 05ddc3f7be7..2e34687d6c7 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -107,6 +107,7 @@ * [Trino (contrib)](reference/offline-stores/trino.md) * [Azure Synapse + Azure SQL (contrib)](reference/offline-stores/mssql.md) * [Clickhouse (contrib)](reference/offline-stores/clickhouse.md) + * [Ray (contrib)](reference/offline-stores/ray.md) * [Remote Offline](reference/offline-stores/remote-offline-store.md) * [Online stores](reference/online-stores/README.md) * [Overview](reference/online-stores/overview.md) @@ -143,6 +144,7 @@ * [Snowflake](reference/compute-engine/snowflake.md) * [AWS Lambda (alpha)](reference/compute-engine/lambda.md) * [Spark (contrib)](reference/compute-engine/spark.md) + * [Ray (contrib)](reference/compute-engine/ray.md) * [Feature repository](reference/feature-repository/README.md) * [feature\_store.yaml](reference/feature-repository/feature-store-yaml.md) * [.feastignore](reference/feature-repository/feast-ignore.md) diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md index c4a2f87f54d..dad2ede75a6 100644 --- a/docs/reference/compute-engine/README.md +++ b/docs/reference/compute-engine/README.md @@ -57,6 +57,14 @@ An example of built output from FeatureBuilder: - Supports point-in-time joins and large-scale materialization - Integrates with `SparkOfflineStore` and `SparkMaterializationJob` +### โšก RayComputeEngine (contrib) + +- Distributed DAG execution via Ray +- Intelligent join strategies (broadcast vs distributed) +- Automatic resource management and optimization +- Integrates with `RayOfflineStore` and `RayMaterializationJob` +- See [Ray Compute Engine documentation](ray.md) for details + ### ๐Ÿงช LocalComputeEngine {% page-ref page="local.md" %} diff --git a/docs/reference/compute-engine/ray.md b/docs/reference/compute-engine/ray.md new file mode 100644 index 00000000000..4ecc449e40b --- /dev/null +++ b/docs/reference/compute-engine/ray.md @@ -0,0 +1,377 @@ +# Ray Compute Engine (contrib) + +The Ray compute engine is a distributed compute implementation that leverages 
[Ray](https://www.ray.io/) for executing feature pipelines including transformations, aggregations, joins, and materializations. It provides scalable and efficient distributed processing for both `materialize()` and `get_historical_features()` operations. + +## Overview + +The Ray compute engine provides: +- **Distributed DAG Execution**: Executes feature computation DAGs across Ray clusters +- **Intelligent Join Strategies**: Automatic selection between broadcast and distributed joins +- **Lazy Evaluation**: Deferred execution for optimal performance +- **Resource Management**: Automatic scaling and resource optimization +- **Point-in-Time Joins**: Efficient temporal joins for historical feature retrieval + +## Architecture + +The Ray compute engine follows Feast's DAG-based architecture: + +``` +EntityDF โ†’ RayReadNode โ†’ RayJoinNode โ†’ RayFilterNode โ†’ RayAggregationNode โ†’ RayTransformationNode โ†’ Output +``` + +### Core Components + +| Component | Description | +|-----------|-------------| +| `RayComputeEngine` | Main engine implementing `ComputeEngine` interface | +| `RayFeatureBuilder` | Constructs DAG from Feature View definitions | +| `RayDAGNode` | Ray-specific DAG node implementations | +| `RayDAGRetrievalJob` | Executes retrieval plans and returns results | +| `RayMaterializationJob` | Handles materialization job tracking | + +## Configuration + +Configure the Ray compute engine in your `feature_store.yaml`: + +```yaml +project: my_project +registry: data/registry.db +provider: local +offline_store: + type: ray + storage_path: data/ray_storage +batch_engine: + type: ray.engine + max_workers: 4 # Optional: Maximum number of workers + enable_optimization: true # Optional: Enable performance optimizations + broadcast_join_threshold_mb: 100 # Optional: Broadcast join threshold (MB) + max_parallelism_multiplier: 2 # Optional: Parallelism multiplier + target_partition_size_mb: 64 # Optional: Target partition size (MB) + window_size_for_joins: "1H" # Optional: Time window for distributed joins + ray_address: localhost:10001 # Optional: Ray cluster address +``` + +### Configuration Options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `type` | string | `"ray.engine"` | Must be `ray.engine` | +| `max_workers` | int | None (uses all cores) | Maximum number of Ray workers | +| `enable_optimization` | boolean | true | Enable performance optimizations | +| `broadcast_join_threshold_mb` | int | 100 | Size threshold for broadcast joins (MB) | +| `max_parallelism_multiplier` | int | 2 | Parallelism as multiple of CPU cores | +| `target_partition_size_mb` | int | 64 | Target partition size (MB) | +| `window_size_for_joins` | string | "1H" | Time window for distributed joins | +| `ray_address` | string | None | Ray cluster address (None = local Ray) | +| `enable_distributed_joins` | boolean | true | Enable distributed joins for large datasets | +| `staging_location` | string | None | Remote path for batch materialization jobs | +| `ray_conf` | dict | None | Ray configuration parameters | +| `execution_timeout_seconds` | int | None | Timeout for job execution in seconds | + +## Usage Examples + +### Basic Historical Feature Retrieval + +```python +from feast import FeatureStore +import pandas as pd +from datetime import datetime + +# Initialize feature store with Ray compute engine +store = FeatureStore("feature_store.yaml") + +# Create entity DataFrame +entity_df = pd.DataFrame({ + "driver_id": [1, 2, 3, 4, 5], + "event_timestamp": 
[datetime.now()] * 5 +}) + +# Get historical features using Ray compute engine +features = store.get_historical_features( + entity_df=entity_df, + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:total_distance" + ] +) + +# Convert to DataFrame +df = features.to_df() +print(f"Retrieved {len(df)} rows with {len(df.columns)} columns") +``` + +### Batch Materialization + +```python +from datetime import datetime, timedelta + +# Materialize features using Ray compute engine +store.materialize( + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), + feature_views=["driver_stats", "customer_stats"] +) + +# The Ray compute engine handles: +# - Distributed data processing +# - Optimal join strategies +# - Resource management +# - Progress tracking +``` + +### Large-Scale Feature Retrieval + +```python +# Handle large entity datasets efficiently +large_entity_df = pd.DataFrame({ + "driver_id": range(1, 1000000), # 1M entities + "event_timestamp": [datetime.now()] * 1000000 +}) + +# Ray compute engine automatically: +# - Partitions data optimally +# - Selects appropriate join strategies +# - Distributes computation across cluster +features = store.get_historical_features( + entity_df=large_entity_df, + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:total_distance", + "customer_stats:lifetime_value" + ] +).to_df() +``` + +### Advanced Configuration + +```yaml +# Production-ready configuration +batch_engine: + type: ray.engine + # Resource configuration + max_workers: 16 + max_parallelism_multiplier: 4 + + # Performance optimization + enable_optimization: true + broadcast_join_threshold_mb: 50 + target_partition_size_mb: 128 + + # Distributed join configuration + window_size_for_joins: "30min" + + # Ray cluster configuration + ray_address: "ray://head-node:10001" +``` + +### Complete Example Configuration + +Here's a complete example configuration showing how to use Ray offline store with Ray compute engine: + +```yaml +# Complete example configuration for Ray offline store + Ray compute engine +# This shows how to use both components together for distributed processing + +project: my_feast_project +registry: data/registry.db +provider: local + +# Ray offline store configuration +# Handles data I/O operations (reading/writing data) +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data # Optional: Path for storing datasets + ray_address: localhost:10001 # Optional: Ray cluster address + +# Ray compute engine configuration +# Handles complex feature computation and distributed processing +batch_engine: + type: ray.engine + + # Resource configuration + max_workers: 8 # Maximum number of Ray workers + max_parallelism_multiplier: 2 # Parallelism as multiple of CPU cores + + # Performance optimization + enable_optimization: true # Enable performance optimizations + broadcast_join_threshold_mb: 100 # Broadcast join threshold (MB) + target_partition_size_mb: 64 # Target partition size (MB) + + # Distributed join configuration + window_size_for_joins: "1H" # Time window for distributed joins + + # Ray cluster configuration (inherits from offline_store if not specified) + ray_address: localhost:10001 # Ray cluster address +``` + +## DAG Node Types + +The Ray compute engine implements several specialized DAG nodes: + +### RayReadNode + +Reads data from Ray-compatible sources: +- Supports Parquet, CSV, and other formats +- Handles partitioning and schema inference +- Applies field mappings and filters + +### RayJoinNode + +Performs distributed joins: 
+- **Broadcast Join**: For small datasets (<100MB) +- **Distributed Join**: For large datasets with time-based windowing +- **Automatic Strategy Selection**: Based on dataset size and cluster resources + +### RayFilterNode + +Applies filters and time-based constraints: +- TTL-based filtering +- Timestamp range filtering +- Custom predicate filtering + +### RayAggregationNode + +Handles feature aggregations: +- Windowed aggregations +- Grouped aggregations +- Custom aggregation functions + +### RayTransformationNode + +Applies feature transformations: +- Row-level transformations +- Column-level transformations +- Custom transformation functions + +### RayWriteNode + +Writes results to various targets: +- Online stores +- Offline stores +- Temporary storage + +## Join Strategies + +The Ray compute engine automatically selects optimal join strategies: + +### Broadcast Join + +Used for small feature datasets: +- Automatically selected when feature data < 100MB +- Features are cached in Ray's object store +- Entities are distributed across cluster +- Each worker gets a copy of feature data + +### Distributed Windowed Join + +Used for large feature datasets: +- Automatically selected when feature data > 100MB +- Data is partitioned by time windows +- Point-in-time joins within each window +- Results are combined across windows + +### Strategy Selection Logic + +```python +def select_join_strategy(feature_size_mb, threshold_mb): + if feature_size_mb < threshold_mb: + return "broadcast" + else: + return "distributed_windowed" +``` + +## Performance Optimization + +### Automatic Optimization + +The Ray compute engine includes several automatic optimizations: + +1. **Partition Optimization**: Automatically determines optimal partition sizes +2. **Join Strategy Selection**: Chooses between broadcast and distributed joins +3. **Resource Allocation**: Scales workers based on available resources +4. **Memory Management**: Handles out-of-core processing for large datasets + +### Manual Tuning + +For specific workloads, you can fine-tune performance: + +```yaml +batch_engine: + type: ray.engine + # Fine-tuning for high-throughput scenarios + broadcast_join_threshold_mb: 200 # Larger broadcast threshold + max_parallelism_multiplier: 1 # Conservative parallelism + target_partition_size_mb: 512 # Larger partitions + window_size_for_joins: "2H" # Larger time windows +``` + +### Monitoring and Metrics + +Monitor Ray compute engine performance: + +```python +import ray + +# Check cluster resources +resources = ray.cluster_resources() +print(f"Available CPUs: {resources.get('CPU', 0)}") +print(f"Available memory: {resources.get('memory', 0) / 1e9:.2f} GB") + +# Monitor job progress +job = store.get_historical_features(...) 
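+# Hedged sketch (not an engine API): compare currently free resources against the
+# totals printed above while the job is consumed. ray.available_resources() is a
+# standard Ray call that reports unclaimed CPUs/memory; `store` and the "..."
+# arguments are placeholders from this example.
+free = ray.available_resources()
+print(f"Free CPUs during retrieval: {free.get('CPU', 0)} / {resources.get('CPU', 0)}")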
+# Ray compute engine provides built-in progress tracking +``` + +## Integration Examples + +### With Spark Offline Store + +```yaml +# Use Ray compute engine with Spark offline store +offline_store: + type: spark + spark_conf: + spark.executor.memory: "4g" + spark.executor.cores: "2" +batch_engine: + type: ray.engine + max_workers: 8 + enable_optimization: true +``` + +### With Cloud Storage + +```yaml +# Use Ray compute engine with cloud storage +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data +batch_engine: + type: ray.engine + ray_address: "ray://ray-cluster:10001" + broadcast_join_threshold_mb: 50 +``` + +### With Feature Transformations + +```python +from feast import FeatureView, Field +from feast.types import Float64 +from feast.on_demand_feature_view import on_demand_feature_view + +@on_demand_feature_view( + sources=["driver_stats"], + schema=[Field(name="trips_per_hour", dtype=Float64)] +) +def trips_per_hour(features_df): + features_df["trips_per_hour"] = features_df["avg_daily_trips"] / 24 + return features_df + +# Ray compute engine handles transformations efficiently +features = store.get_historical_features( + entity_df=entity_df, + features=["trips_per_hour:trips_per_hour"] +) +``` + +For more information, see the [Ray documentation](https://docs.ray.io/en/latest/) and [Ray Data guide](https://docs.ray.io/en/latest/data/getting-started.html). \ No newline at end of file diff --git a/docs/reference/offline-stores/ray.md b/docs/reference/offline-stores/ray.md new file mode 100644 index 00000000000..58f62c34ece --- /dev/null +++ b/docs/reference/offline-stores/ray.md @@ -0,0 +1,499 @@ +# Ray Offline Store (contrib) + +> **โš ๏ธ Contrib Plugin:** +> The Ray offline store is a contributed plugin. It may not be as stable or fully supported as core offline stores. Use with caution in production and report issues to the Feast community. + +The Ray offline store is a data I/O implementation that leverages [Ray](https://www.ray.io/) for reading and writing data from various sources. It focuses on efficient data access operations, while complex feature computation is handled by the [Ray Compute Engine](../compute-engine/ray.md). + +## Overview + +The Ray offline store provides: +- Ray-based data reading from file sources (Parquet, CSV, etc.) +- Support for both local and distributed Ray clusters +- Integration with various storage backends (local files, S3, GCS, HDFS) +- Efficient data filtering and column selection +- Timestamp-based data processing with timezone awareness + + +## Functionality Matrix + + +| Method | Supported | +|----------------------------------|-----------| +| get_historical_features | Yes | +| pull_latest_from_table_or_query | Yes | +| pull_all_from_table_or_query | Yes | +| offline_write_batch | Yes | +| write_logged_features | Yes | + + +| RetrievalJob Feature | Supported | +|----------------------------------|-----------| +| export to dataframe | Yes | +| export to arrow table | Yes | +| persist results in offline store| Yes | +| local execution of ODFVs | Yes | +| preview query plan | Yes | +| read partitioned data | Yes | + + +## โš ๏ธ Important: Resource Management + +**By default, Ray will use all available system resources (CPU and memory).** This can cause issues in test environments or when experimenting locally, potentially leading to system crashes or unresponsiveness. + +**For testing and local experimentation, we strongly recommend:** + +1. 
**Configure resource limits** in your `feature_store.yaml` (see [Resource Management and Testing](#resource-management-and-testing) section below) + +This will limit Ray to safe resource levels for testing and development. + + +## Architecture + +The Ray offline store follows Feast's architectural separation: +- **Ray Offline Store**: Handles data I/O operations (reading/writing data) +- **Ray Compute Engine**: Handles complex feature computation and joins +- **Clear Separation**: Each component has a single responsibility + +For complex feature processing, historical feature retrieval, and distributed joins, use the [Ray Compute Engine](../compute-engine/ray.md). + +## Configuration + +The Ray offline store can be configured in your `feature_store.yaml` file. Below are two main configuration patterns: + +### Basic Ray Offline Store + +For simple data I/O operations without distributed processing: + +```yaml +project: my_project +registry: data/registry.db +provider: local +offline_store: + type: ray + storage_path: data/ray_storage # Optional: Path for storing datasets + ray_address: localhost:10001 # Optional: Ray cluster address +``` + +### Ray Offline Store + Compute Engine + +For distributed feature processing with advanced capabilities: + +```yaml +project: my_project +registry: data/registry.db +provider: local + +# Ray offline store for data I/O operations +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data # Optional: Path for storing datasets + ray_address: localhost:10001 # Optional: Ray cluster address + +# Ray compute engine for distributed feature processing +batch_engine: + type: ray.engine + + # Resource configuration + max_workers: 8 # Maximum number of Ray workers + max_parallelism_multiplier: 2 # Parallelism as multiple of CPU cores + + # Performance optimization + enable_optimization: true # Enable performance optimizations + broadcast_join_threshold_mb: 100 # Broadcast join threshold (MB) + target_partition_size_mb: 64 # Target partition size (MB) + + # Distributed join configuration + window_size_for_joins: "1H" # Time window for distributed joins + enable_distributed_joins: true # Enable distributed joins + + # Ray cluster configuration (optional) + ray_address: localhost:10001 # Ray cluster address + staging_location: s3://my-bucket/staging # Remote staging location +``` + +### Local Development Configuration + +For local development and testing: + +```yaml +project: my_local_project +registry: data/registry.db +provider: local + +offline_store: + type: ray + storage_path: ./data/ray_storage + # Conservative settings for local development + broadcast_join_threshold_mb: 25 + max_parallelism_multiplier: 1 + target_partition_size_mb: 16 + enable_ray_logging: false + # Memory constraints to prevent OOM in test/development environments + ray_conf: + num_cpus: 1 + object_store_memory: 104857600 # 100MB + _memory: 524288000 # 500MB + +batch_engine: + type: ray.engine + max_workers: 2 + enable_optimization: false +``` + +### Production Configuration + +For production deployments with distributed Ray cluster: + +```yaml +project: my_production_project +registry: s3://my-bucket/registry.db +provider: local + +offline_store: + type: ray + storage_path: s3://my-production-bucket/feast-data + ray_address: "ray://production-head-node:10001" + +batch_engine: + type: ray.engine + max_workers: 32 + max_parallelism_multiplier: 4 + enable_optimization: true + broadcast_join_threshold_mb: 50 + target_partition_size_mb: 128 + window_size_for_joins: "30min" + 
ray_address: "ray://production-head-node:10001" + staging_location: s3://my-production-bucket/staging +``` + +### Configuration Options + +#### Ray Offline Store Options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `type` | string | Required | Must be `feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore` or `ray` | +| `storage_path` | string | None | Path for storing temporary files and datasets | +| `ray_address` | string | None | Address of the Ray cluster (e.g., "localhost:10001") | +| `ray_conf` | dict | None | Ray initialization parameters for resource management (e.g., memory, CPU limits) | + +#### Ray Compute Engine Options + +For Ray compute engine configuration options, see the [Ray Compute Engine documentation](../compute-engine/ray.md#configuration-options). + +## Resource Management and Testing + +### Overview + +**By default, Ray will use all available system resources (CPU and memory).** This can cause issues in test environments or when experimenting locally, potentially leading to system crashes or unresponsiveness. + +### Resource Configuration + +For custom resource control, configure limits in your `feature_store.yaml`: + +#### Conservative Settings (Local Development/Testing) + +```yaml +offline_store: + type: ray + storage_path: ./data/ray_storage + # Resource optimization settings + broadcast_join_threshold_mb: 25 # Smaller datasets for broadcast joins + max_parallelism_multiplier: 1 # Reduced parallelism + target_partition_size_mb: 16 # Smaller partition sizes + enable_ray_logging: false # Disable verbose logging + # Memory constraints to prevent OOM in test environments + ray_conf: + num_cpus: 1 + object_store_memory: 104857600 # 100MB + _memory: 524288000 # 500MB +``` + +#### Production Settings + +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data + ray_address: "ray://production-cluster:10001" + # Optimized for production workloads + broadcast_join_threshold_mb: 100 + max_parallelism_multiplier: 2 + target_partition_size_mb: 64 + enable_ray_logging: true +``` + +### Resource Configuration Options + +| Setting | Default | Description | Testing Recommendation | +|---------|---------|-------------|------------------------| +| `broadcast_join_threshold_mb` | 100 | Size threshold for broadcast joins (MB) | 25 | +| `max_parallelism_multiplier` | 2 | Parallelism as multiple of CPU cores | 1 | +| `target_partition_size_mb` | 64 | Target partition size (MB) | 16 | +| `enable_ray_logging` | false | Enable Ray progress bars and logging | false | + +### Environment-Specific Recommendations + +#### Local Development +```yaml +# feature_store.yaml +offline_store: + type: ray + broadcast_join_threshold_mb: 25 + max_parallelism_multiplier: 1 + target_partition_size_mb: 16 +``` + +#### Production Clusters +```yaml +# feature_store.yaml +offline_store: + type: ray + ray_address: "ray://cluster-head:10001" + broadcast_join_threshold_mb: 200 + max_parallelism_multiplier: 4 +``` + +## Usage Examples + +### Basic Data Source Reading + +```python +from feast import FeatureStore, FeatureView, FileSource +from feast.types import Float32, Int64 +from datetime import timedelta + +# Define a feature view +driver_stats = FeatureView( + name="driver_stats", + entities=["driver_id"], + ttl=timedelta(days=1), + source=FileSource( + path="data/driver_stats.parquet", + timestamp_field="event_timestamp", + ), + schema=[ + ("driver_id", Int64), + ("avg_daily_trips", Float32), + ], +) + +# Initialize feature 
store +store = FeatureStore("feature_store.yaml") + +# The Ray offline store handles data I/O operations +# For complex feature computation, use Ray Compute Engine +``` + +### Direct Data Access + +The Ray offline store provides direct access to underlying data: + +```python +from feast.infra.offline_stores.contrib.ray_offline_store.ray import RayOfflineStore +from datetime import datetime, timedelta + +# Pull latest data from a table +job = RayOfflineStore.pull_latest_from_table_or_query( + config=store.config, + data_source=driver_stats.source, + join_key_columns=["driver_id"], + feature_name_columns=["avg_daily_trips"], + timestamp_field="event_timestamp", + created_timestamp_column=None, + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), +) + +# Convert to pandas DataFrame +df = job.to_df() +print(f"Retrieved {len(df)} rows") + +# Convert to Arrow Table +arrow_table = job.to_arrow() + +# Get Ray dataset directly +ray_dataset = job.to_ray_dataset() +``` + +### Batch Writing + +The Ray offline store supports batch writing for materialization: + +```python +import pyarrow as pa +from feast import FeatureView + +# Create sample data +data = pa.table({ + "driver_id": [1, 2, 3, 4, 5], + "avg_daily_trips": [10.5, 15.2, 8.7, 12.3, 9.8], + "event_timestamp": [datetime.now()] * 5 +}) + +# Write batch data +RayOfflineStore.offline_write_batch( + config=store.config, + feature_view=driver_stats, + table=data, + progress=lambda x: print(f"Wrote {x} rows") +) +``` + +### Saved Dataset Persistence + +The Ray offline store supports persisting datasets for later analysis: + +```python +from feast.infra.offline_stores.file_source import SavedDatasetFileStorage + +# Create storage destination +storage = SavedDatasetFileStorage(path="data/training_dataset.parquet") + +# Persist the dataset +job.persist(storage, allow_overwrite=False) + +# Create a saved dataset in the registry +saved_dataset = store.create_saved_dataset( + from_=job, + name="driver_training_dataset", + storage=storage, + tags={"purpose": "data_access", "version": "v1"} +) + +print(f"Saved dataset created: {saved_dataset.name}") +``` + +### Remote Storage Support + +The Ray offline store supports various remote storage backends: + +```python +# S3 storage +s3_storage = SavedDatasetFileStorage(path="s3://my-bucket/datasets/driver_features.parquet") +job.persist(s3_storage, allow_overwrite=True) + +# Google Cloud Storage +gcs_storage = SavedDatasetFileStorage(path="gs://my-project-bucket/datasets/driver_features.parquet") +job.persist(gcs_storage, allow_overwrite=True) + +# HDFS +hdfs_storage = SavedDatasetFileStorage(path="hdfs://namenode:8020/datasets/driver_features.parquet") +job.persist(hdfs_storage, allow_overwrite=True) +``` + +### Using Ray Cluster + +To use Ray in cluster mode for distributed data access: + +1. Start a Ray cluster: +```bash +ray start --head --port=10001 +``` + +2. Configure your `feature_store.yaml`: +```yaml +offline_store: + type: ray + ray_address: localhost:10001 + storage_path: s3://my-bucket/features +``` + +3. 
For multiple worker nodes: +```bash +# On worker nodes +ray start --address='head-node-ip:10001' +``` + +### Data Source Validation + +The Ray offline store validates data sources to ensure compatibility: + +```python +from feast.infra.offline_stores.contrib.ray_offline_store.ray import RayOfflineStore + +# Validate a data source +try: + RayOfflineStore.validate_data_source(store.config, driver_stats.source) + print("Data source is valid") +except Exception as e: + print(f"Data source validation failed: {e}") +``` + +## Limitations + +The Ray offline store has the following limitations: + +1. **File Sources Only**: Currently supports only `FileSource` data sources +2. **No Direct SQL**: Does not support SQL query interfaces +3. **No Online Writes**: Cannot write directly to online stores +4. **No Complex Transformations**: The Ray offline store focuses on data I/O operations. For complex feature transformations (aggregations, joins, custom UDFs), use the [Ray Compute Engine](../compute-engine/ray.md) instead + +## Integration with Ray Compute Engine + +For complex feature processing operations, use the Ray offline store in combination with the [Ray Compute Engine](../compute-engine/ray.md). See the **Ray Offline Store + Compute Engine** configuration example in the [Configuration](#configuration) section above for a complete setup. + + +For more advanced troubleshooting, refer to the [Ray documentation](https://docs.ray.io/en/latest/data/getting-started.html). + +## Quick Reference + +### Configuration Templates + +**Basic Ray Offline Store** (local development): +```yaml +offline_store: + type: ray + storage_path: ./data/ray_storage + # Conservative settings for local development + broadcast_join_threshold_mb: 25 + max_parallelism_multiplier: 1 + target_partition_size_mb: 16 + enable_ray_logging: false +``` + +**Ray Offline Store + Compute Engine** (distributed processing): +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data + +batch_engine: + type: ray.engine + max_workers: 8 + enable_optimization: true + broadcast_join_threshold_mb: 100 +``` + +### Key Commands + +```python +# Initialize feature store +store = FeatureStore("feature_store.yaml") + +# Get historical features (uses compute engine if configured) +features = store.get_historical_features(entity_df=df, features=["fv:feature"]) + +# Direct data access (uses offline store) +job = RayOfflineStore.pull_latest_from_table_or_query(...) +df = job.to_df() + +# Offline write batch (materialization) +# Create sample data for materialization +data = pa.table({ + "driver_id": [1, 2, 3, 4, 5], + "avg_daily_trips": [10.5, 15.2, 8.7, 12.3, 9.8], + "event_timestamp": [datetime.now()] * 5 +}) + +# Write batch to offline store +RayOfflineStore.offline_write_batch( + config=store.config, + feature_view=driver_stats_fv, + table=data, + progress=lambda rows: print(f"Processed {rows} rows") +) +``` + +For complete examples, see the [Configuration](#configuration) section above. 
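+
+### End-to-End Sketch
+
+The sketch below is illustrative rather than canonical: it assumes a local `data/driver_stats.parquet` file, one of the `feature_store.yaml` templates above, and the current Feast object API (`Entity`, `Field`). Adjust names and paths to your repository.
+
+```python
+from datetime import datetime, timedelta
+
+import pandas as pd
+
+from feast import Entity, FeatureStore, FeatureView, Field, FileSource
+from feast.types import Float32
+
+# Illustrative definitions; the path and names are placeholders.
+driver = Entity(name="driver", join_keys=["driver_id"])
+
+driver_stats = FeatureView(
+    name="driver_stats",
+    entities=[driver],
+    ttl=timedelta(days=1),
+    schema=[Field(name="avg_daily_trips", dtype=Float32)],
+    source=FileSource(
+        path="data/driver_stats.parquet",
+        timestamp_field="event_timestamp",
+    ),
+)
+
+store = FeatureStore(repo_path=".")  # directory containing feature_store.yaml
+store.apply([driver, driver_stats])
+
+entity_df = pd.DataFrame(
+    {"driver_id": [1, 2, 3], "event_timestamp": [datetime.now()] * 3}
+)
+training_df = store.get_historical_features(
+    entity_df=entity_df,
+    features=["driver_stats:avg_daily_trips"],
+).to_df()
+print(training_df.head())
+```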
\ No newline at end of file diff --git a/infra/feast-operator/api/v1alpha1/featurestore_types.go b/infra/feast-operator/api/v1alpha1/featurestore_types.go index 8587fc98240..756c2e17ab1 100644 --- a/infra/feast-operator/api/v1alpha1/featurestore_types.go +++ b/infra/feast-operator/api/v1alpha1/featurestore_types.go @@ -315,7 +315,7 @@ var ValidOfflineStoreFilePersistenceTypes = []string{ // OfflineStoreDBStorePersistence configures the DB store persistence for the offline store service type OfflineStoreDBStorePersistence struct { // Type of the persistence type you want to use. - // +kubebuilder:validation:Enum=snowflake.offline;bigquery;redshift;spark;postgres;trino;athena;mssql;couchbase.offline;clickhouse + // +kubebuilder:validation:Enum=snowflake.offline;bigquery;redshift;spark;postgres;trino;athena;mssql;couchbase.offline;clickhouse;ray Type string `json:"type"` // Data store parameters should be placed as-is from the "feature_store.yaml" under the secret key. "registry_type" & "type" fields should be removed. SecretRef corev1.LocalObjectReference `json:"secretRef"` @@ -334,6 +334,7 @@ var ValidOfflineStoreDBStorePersistenceTypes = []string{ "mssql", "couchbase.offline", "clickhouse", + "ray", } // OnlineStore configures the online store service diff --git a/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml b/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml index 4d585d0feac..fcc382974f9 100644 --- a/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml +++ b/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml @@ -50,7 +50,7 @@ metadata: } ] capabilities: Basic Install - createdAt: "2025-07-21T20:53:09Z" + createdAt: "2025-07-25T09:58:54Z" operators.operatorframework.io/builder: operator-sdk-v1.38.0 operators.operatorframework.io/project_layout: go.kubebuilder.io/v4 name: feast-operator.v0.51.0 diff --git a/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml b/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml index b7718e57104..701ed9bf052 100644 --- a/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml +++ b/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml @@ -842,6 +842,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef @@ -4806,6 +4807,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef diff --git a/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml b/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml index b2fed6992d5..360fbba5453 100644 --- a/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml +++ b/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml @@ -842,6 +842,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef @@ -4806,6 +4807,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef diff --git a/infra/feast-operator/dist/install.yaml b/infra/feast-operator/dist/install.yaml index b79489f7a29..add5c13c9a5 100644 --- a/infra/feast-operator/dist/install.yaml +++ b/infra/feast-operator/dist/install.yaml @@ -850,6 +850,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef @@ -4814,6 +4815,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef diff --git a/pyproject.toml b/pyproject.toml index 
1ee36c3a102..357a20fcf17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ rag = [ "transformers>=4.36.0", "datasets>=3.6.0", ] +ray = ["ray>=2.47.0"] redis = [ "redis>=4.2.2,<5", "hiredis>=2.0.0,<3", @@ -167,7 +168,7 @@ ci = [ "types-setuptools", "types-tabulate", "virtualenv<20.24.2", - "feast[aws, azure, cassandra, clickhouse, couchbase, delta, docling, duckdb, elasticsearch, faiss, gcp, ge, go, grpcio, hazelcast, hbase, ibis, ikv, k8s, mcp, milvus, mssql, mysql, opentelemetry, spark, trino, postgres, pytorch, qdrant, rag, redis, singlestore, snowflake, sqlite_vec]" + "feast[aws, azure, cassandra, clickhouse, couchbase, delta, docling, duckdb, elasticsearch, faiss, gcp, ge, go, grpcio, hazelcast, hbase, ibis, ikv, k8s, mcp, milvus, mssql, mysql, opentelemetry, spark, trino, postgres, pytorch, qdrant, rag, ray, redis, singlestore, snowflake, sqlite_vec]" ] nlp = ["feast[docling, milvus, pytorch, rag]"] dev = ["feast[ci]"] diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index b77737a8bd5..4a086f1b99b 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -510,7 +510,8 @@ def _from_proto_internal( if feature_view_proto.spec.ttl.ToNanoseconds() == 0 else feature_view_proto.spec.ttl.ToTimedelta() ), - source=batch_source if batch_source else source_views, + source=source_views if source_views else batch_source, + sink_source=batch_source if source_views else None, ) if stream_source: feature_view.stream_source = stream_source diff --git a/sdk/python/feast/feature_view_utils.py b/sdk/python/feast/feature_view_utils.py new file mode 100644 index 00000000000..daf28e09dec --- /dev/null +++ b/sdk/python/feast/feature_view_utils.py @@ -0,0 +1,229 @@ +""" +Utility functions for feature view operations including source resolution. +""" + +import logging +import typing +from dataclasses import dataclass +from typing import Callable, Optional + +if typing.TYPE_CHECKING: + from feast.data_source import DataSource + from feast.feature_view import FeatureView + from feast.repo_config import RepoConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class FeatureViewSourceInfo: + """Information about a feature view's data source resolution.""" + + data_source: "DataSource" + source_type: str + has_transformation: bool + transformation_func: Optional[Callable] = None + source_description: str = "" + + +def has_transformation(feature_view: "FeatureView") -> bool: + """Check if a feature view has transformations (UDF or feature_transformation).""" + return ( + getattr(feature_view, "udf", None) is not None + or getattr(feature_view, "feature_transformation", None) is not None + ) + + +def get_transformation_function(feature_view: "FeatureView") -> Optional[Callable]: + """Extract the transformation function from a feature view.""" + feature_transformation = getattr(feature_view, "feature_transformation", None) + if feature_transformation: + # Use feature_transformation if available (preferred) + if hasattr(feature_transformation, "udf") and callable( + feature_transformation.udf + ): + return feature_transformation.udf + + # Fallback to direct UDF + udf = getattr(feature_view, "udf", None) + if udf and callable(udf): + return udf + + return None + + +def find_original_source_view(feature_view: "FeatureView") -> "FeatureView": + """ + Recursively find the original source feature view that has a batch_source. 
+ For derived feature views, this follows the source_views chain until it finds + a feature view with an actual DataSource (batch_source). + """ + current_view = feature_view + while hasattr(current_view, "source_views") and current_view.source_views: + if not current_view.source_views: + break + current_view = current_view.source_views[0] # Assuming single source for now + return current_view + + +def check_sink_source_exists(data_source: "DataSource") -> bool: + """ + Check if a sink_source file actually exists. + Args: + data_source: The DataSource to check + Returns: + bool: True if the source exists, False otherwise + """ + try: + import fsspec + + # Get the source path + if hasattr(data_source, "path"): + source_path = data_source.path + else: + source_path = str(data_source) + + fs, path_in_fs = fsspec.core.url_to_fs(source_path) + return fs.exists(path_in_fs) + except Exception as e: + logger.warning(f"Failed to check if source exists: {e}") + return False + + +def resolve_feature_view_source( + feature_view: "FeatureView", + config: Optional["RepoConfig"] = None, + is_materialization: bool = False, +) -> FeatureViewSourceInfo: + """ + Resolve the appropriate data source for a feature view. + + This handles the complex logic of determining whether to read from: + 1. sink_source (materialized data from parent views) + 2. batch_source (original data source) + 3. Recursive resolution for derived views + + Args: + feature_view: The feature view to resolve + config: Repository configuration (optional) + is_materialization: Whether this is during materialization (affects derived view handling) + + Returns: + FeatureViewSourceInfo: Information about the resolved source + """ + view_has_transformation = has_transformation(feature_view) + transformation_func = ( + get_transformation_function(feature_view) if view_has_transformation else None + ) + + # Check if this is a derived feature view (has source_views) + is_derived_view = ( + hasattr(feature_view, "source_views") and feature_view.source_views + ) + + if not is_derived_view: + # Regular feature view - use its batch_source directly + return FeatureViewSourceInfo( + data_source=feature_view.batch_source, + source_type="batch_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Direct batch_source for {feature_view.name}", + ) + + # This is a derived feature view - need to resolve parent source + if not feature_view.source_views: + raise ValueError( + f"Derived feature view {feature_view.name} has no source_views" + ) + parent_view = feature_view.source_views[0] # Assuming single source for now + + # For derived views: distinguish between materialization and historical retrieval + if ( + hasattr(parent_view, "sink_source") + and parent_view.sink_source + and is_materialization + ): + # During materialization, try to use sink_source if it exists + if check_sink_source_exists(parent_view.sink_source): + logger.debug( + f"Materialization: Using parent {parent_view.name} sink_source" + ) + return FeatureViewSourceInfo( + data_source=parent_view.sink_source, + source_type="sink_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Parent {parent_view.name} sink_source for derived view {feature_view.name}", + ) + else: + logger.info( + f"Parent {parent_view.name} sink_source doesn't exist during materialization" + ) + + # Check if parent is also a derived view first - if so, recursively resolve to original 
source + if hasattr(parent_view, "source_views") and parent_view.source_views: + # Parent is also a derived view - recursively find original source + original_source_view = find_original_source_view(parent_view) + return FeatureViewSourceInfo( + data_source=original_source_view.batch_source, + source_type="original_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Original source {original_source_view.name} batch_source for derived view {feature_view.name} (via {parent_view.name})", + ) + elif hasattr(parent_view, "batch_source") and parent_view.batch_source: + # Parent has a direct batch_source, use it + return FeatureViewSourceInfo( + data_source=parent_view.batch_source, + source_type="batch_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Parent {parent_view.name} batch_source for derived view {feature_view.name}", + ) + else: + # No valid source found + raise ValueError( + f"Unable to resolve data source for derived feature view {feature_view.name} via parent {parent_view.name}" + ) + + +def resolve_feature_view_source_with_fallback( + feature_view: "FeatureView", + config: Optional["RepoConfig"] = None, + is_materialization: bool = False, +) -> FeatureViewSourceInfo: + """ + Resolve feature view source with fallback error handling. + + This version includes additional error handling and fallback logic + for cases where the primary resolution fails. + """ + try: + return resolve_feature_view_source(feature_view, config, is_materialization) + except Exception as e: + logger.warning(f"Primary source resolution failed for {feature_view.name}: {e}") + + # Fallback: try to find any available source + if hasattr(feature_view, "batch_source") and feature_view.batch_source: + return FeatureViewSourceInfo( + data_source=feature_view.batch_source, + source_type="fallback_batch_source", + has_transformation=has_transformation(feature_view), + transformation_func=get_transformation_function(feature_view), + source_description=f"Fallback batch_source for {feature_view.name}", + ) + elif hasattr(feature_view, "source_views") and feature_view.source_views: + # Try the original source view as last resort + original_view = find_original_source_view(feature_view) + return FeatureViewSourceInfo( + data_source=original_view.batch_source, + source_type="fallback_original_source", + has_transformation=has_transformation(feature_view), + transformation_func=get_transformation_function(feature_view), + source_description=f"Fallback original source {original_view.name} for {feature_view.name}", + ) + else: + raise ValueError( + f"Unable to resolve any data source for feature view {feature_view.name}" + ) diff --git a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py index f77fdd0b6c9..5990eea6141 100644 --- a/sdk/python/feast/infra/compute_engines/dag/model.py +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -5,3 +5,4 @@ class DAGFormat(str, Enum): SPARK = "spark" PANDAS = "pandas" ARROW = "arrow" + RAY = "ray" diff --git a/sdk/python/feast/infra/compute_engines/local/compute.py b/sdk/python/feast/infra/compute_engines/local/compute.py index 341b20dee02..556468f5e1d 100644 --- a/sdk/python/feast/infra/compute_engines/local/compute.py +++ b/sdk/python/feast/infra/compute_engines/local/compute.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Union +from typing import Literal, Optional, 
Sequence, Union from feast import ( BatchFeatureView, @@ -22,6 +22,17 @@ LocalRetrievalJob, ) from feast.infra.registry.base_registry import BaseRegistry +from feast.repo_config import FeastConfigBaseModel + + +class LocalComputeEngineConfig(FeastConfigBaseModel): + """Configuration for Local Compute Engine.""" + + type: Literal["local"] = "local" + """Local Compute Engine type selector""" + + backend: Optional[str] = None + """Backend to use for DataFrame operations (e.g., 'pandas', 'polars')""" class LocalComputeEngine(ComputeEngine): diff --git a/sdk/python/feast/infra/compute_engines/ray/__init__.py b/sdk/python/feast/infra/compute_engines/ray/__init__.py new file mode 100644 index 00000000000..7b02d0ca615 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/__init__.py @@ -0,0 +1,38 @@ +""" +Ray Compute Engine for Feast + +This module provides a Ray-based compute engine for distributed feature computation. +It includes: +- RayComputeEngine: Main compute engine implementation +- RayComputeEngineConfig: Configuration for the compute engine +- Ray DAG nodes for distributed processing +""" + +from .compute import RayComputeEngine +from .config import RayComputeEngineConfig +from .feature_builder import RayFeatureBuilder +from .job import RayDAGRetrievalJob, RayMaterializationJob +from .nodes import ( + RayAggregationNode, + RayDedupNode, + RayFilterNode, + RayJoinNode, + RayReadNode, + RayTransformationNode, + RayWriteNode, +) + +__all__ = [ + "RayComputeEngine", + "RayComputeEngineConfig", + "RayDAGRetrievalJob", + "RayMaterializationJob", + "RayFeatureBuilder", + "RayReadNode", + "RayJoinNode", + "RayFilterNode", + "RayAggregationNode", + "RayDedupNode", + "RayTransformationNode", + "RayWriteNode", +] diff --git a/sdk/python/feast/infra/compute_engines/ray/compute.py b/sdk/python/feast/infra/compute_engines/ray/compute.py new file mode 100644 index 00000000000..7bf7e15dfb0 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/compute.py @@ -0,0 +1,286 @@ +import logging +from datetime import datetime +from typing import Sequence, Union + +import ray + +from feast import ( + BatchFeatureView, + Entity, + FeatureView, + OnDemandFeatureView, + StreamFeatureView, +) +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.ray.feature_builder import RayFeatureBuilder +from feast.infra.compute_engines.ray.job import ( + RayDAGRetrievalJob, + RayMaterializationJob, +) +from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.registry.base_registry import BaseRegistry + +logger = logging.getLogger(__name__) + + +class RayComputeEngine(ComputeEngine): + """ + Ray-based compute engine for distributed feature computation. + This engine uses Ray for distributed processing of features, enabling + efficient point-in-time joins, aggregations, and transformations across + large datasets. 
+ """ + + def __init__( + self, + offline_store, + online_store, + repo_config, + **kwargs, + ): + super().__init__( + offline_store=offline_store, + online_store=online_store, + repo_config=repo_config, + **kwargs, + ) + self.config = repo_config.batch_engine + assert isinstance(self.config, RayComputeEngineConfig) + self._ensure_ray_initialized() + + def _ensure_ray_initialized(self): + """Ensure Ray is initialized with proper configuration.""" + if not ray.is_initialized(): + if self.config.ray_address: + ray.init( + address=self.config.ray_address, + ignore_reinit_error=True, + include_dashboard=False, + ) + else: + ray_init_args = { + "ignore_reinit_error": True, + "include_dashboard": False, + } + + # Add configuration from ray_conf if provided + if self.config.ray_conf: + ray_init_args.update(self.config.ray_conf) + + ray.init(**ray_init_args) + + # Configure Ray context for optimal performance + from ray.data.context import DatasetContext + + ctx = DatasetContext.get_current() + ctx.enable_tensor_extension_casting = False + + # Log Ray cluster information + cluster_resources = ray.cluster_resources() + logger.info( + f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " + f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" + ) + + def update( + self, + project: str, + views_to_delete: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + views_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] + ], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + ): + """Ray compute engine doesn't require infrastructure updates.""" + pass + + def teardown_infra( + self, + project: str, + fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], + entities: Sequence[Entity], + ): + """Ray compute engine doesn't require infrastructure teardown.""" + pass + + def _materialize_one( + self, + registry: BaseRegistry, + task: MaterializationTask, + from_offline_store: bool = False, + **kwargs, + ) -> MaterializationJob: + """Materialize features for a single feature view.""" + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" + + if from_offline_store: + logger.warning( + "Materializing from offline store will be deprecated. " + "Please use the new materialization API." 
+ ) + return self._materialize_from_offline_store( + registry=registry, + feature_view=task.feature_view, + start_date=task.start_time, + end_date=task.end_time, + project=task.project, + ) + + try: + # Build typed execution context + context = self.get_execution_context(registry, task) + + # Construct Feature Builder and execute + builder = RayFeatureBuilder(registry, task.feature_view, task, self.config) + plan = builder.build() + result = plan.execute(context) + + # Log execution results + logger.info(f"Materialization completed for {task.feature_view.name}") + + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED, + result=result, + ) + + except Exception as e: + logger.error(f"Materialization failed for {task.feature_view.name}: {e}") + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=e, + ) + + def _materialize_from_offline_store( + self, + registry: BaseRegistry, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + start_date: datetime, + end_date: datetime, + project: str, + ) -> MaterializationJob: + """Legacy materialization method for backward compatibility.""" + from feast.utils import _get_column_names + + job_id = f"{feature_view.name}-{start_date}-{end_date}" + + try: + # Get column information + entities = [ + registry.get_entity(name, project) for name in feature_view.entities + ] + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = _get_column_names(feature_view, entities) + + # Pull data from offline store + retrieval_job = self.offline_store.pull_latest_from_table_or_query( + config=self.repo_config, + data_source=feature_view.batch_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ) + + # Convert to Arrow Table and write to online/offline stores + arrow_table = retrieval_job.to_arrow() + + # Write to online store if enabled + if getattr(feature_view, "online", False): + # TODO: Implement proper online store writing with correct data format conversion + logger.debug( + "Online store writing not implemented yet for Ray compute engine" + ) + + # Write to offline store if enabled (this handles sink_source automatically for derived views) + if getattr(feature_view, "offline", False): + self.offline_store.offline_write_batch( + config=self.repo_config, + feature_view=feature_view, + table=arrow_table, + progress=lambda x: None, + ) + + # For derived views, also ensure data is written to sink_source if it exists + # This is critical for feature view chaining to work properly + sink_source = getattr(feature_view, "sink_source", None) + if sink_source is not None: + logger.debug( + f"Writing derived view {feature_view.name} to sink_source: {sink_source.path}" + ) + + # Write to sink_source using Ray data + try: + ray_dataset = ray.data.from_arrow(arrow_table) + ray_dataset.write_parquet(sink_source.path) + except Exception as e: + logger.error( + f"Failed to write to sink_source {sink_source.path}: {e}" + ) + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED, + ) + + except Exception as e: + logger.error(f"Legacy materialization failed: {e}") + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=e, + ) + + def get_historical_features( + self, registry: BaseRegistry, 
task: HistoricalRetrievalTask + ) -> RetrievalJob: + """Get historical features using Ray DAG execution.""" + if isinstance(task.entity_df, str): + raise NotImplementedError( + "SQL-based entity_df is not yet supported in Ray DAG" + ) + + try: + # Build typed execution context + context = self.get_execution_context(registry, task) + + # Construct Feature Builder and build execution plan + builder = RayFeatureBuilder(registry, task.feature_view, task, self.config) + plan = builder.build() + + return RayDAGRetrievalJob( + plan=plan, + context=context, + config=self.repo_config, + full_feature_names=task.full_feature_name, + on_demand_feature_views=getattr(task, "on_demand_feature_views", None), + feature_refs=getattr(task, "feature_refs", None), + ) + + except Exception as e: + logger.error(f"Historical feature retrieval failed: {e}") + return RayDAGRetrievalJob( + plan=None, + context=None, + config=self.repo_config, + full_feature_names=task.full_feature_name, + on_demand_feature_views=getattr(task, "on_demand_feature_views", None), + feature_refs=getattr(task, "feature_refs", None), + error=e, + ) diff --git a/sdk/python/feast/infra/compute_engines/ray/config.py b/sdk/python/feast/infra/compute_engines/ray/config.py new file mode 100644 index 00000000000..c6d74d262dd --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/config.py @@ -0,0 +1,66 @@ +"""Configuration for Ray compute engine.""" + +from datetime import timedelta +from typing import Any, Dict, Literal, Optional + +from pydantic import StrictStr + +from feast.repo_config import FeastConfigBaseModel + + +class RayComputeEngineConfig(FeastConfigBaseModel): + """Configuration for Ray Compute Engine.""" + + type: Literal["ray.engine"] = "ray.engine" + """Ray Compute Engine type selector""" + + ray_address: Optional[str] = None + """Ray cluster address. If None, uses local Ray cluster.""" + + staging_location: Optional[StrictStr] = None + """Remote path for batch materialization jobs""" + + # Ray-specific performance configurations + broadcast_join_threshold_mb: int = 100 + """Threshold for using broadcast joins (in MB)""" + + enable_distributed_joins: bool = True + """Whether to enable distributed joins for large datasets""" + + max_parallelism_multiplier: int = 2 + """Multiplier for max parallelism based on available CPUs""" + + target_partition_size_mb: int = 64 + """Target partition size in MB""" + + window_size_for_joins: str = "1H" + """Window size for windowed temporal joins""" + + ray_conf: Optional[Dict[str, Any]] = None + """Ray configuration parameters""" + + # Additional configuration options + max_workers: Optional[int] = None + """Maximum number of Ray workers. 
If None, uses all available cores.""" + + enable_optimization: bool = True + """Enable automatic performance optimizations.""" + + execution_timeout_seconds: Optional[int] = None + """Timeout for job execution in seconds.""" + + @property + def window_size_timedelta(self) -> timedelta: + """Convert window size string to timedelta.""" + if self.window_size_for_joins.endswith("H"): + hours = int(self.window_size_for_joins[:-1]) + return timedelta(hours=hours) + elif self.window_size_for_joins.endswith("min"): + minutes = int(self.window_size_for_joins[:-3]) + return timedelta(minutes=minutes) + elif self.window_size_for_joins.endswith("s"): + seconds = int(self.window_size_for_joins[:-1]) + return timedelta(seconds=seconds) + else: + # Default to 1 hour + return timedelta(hours=1) diff --git a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py new file mode 100644 index 00000000000..07c5c6f1113 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py @@ -0,0 +1,317 @@ +import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast + +from feast import FeatureView +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.algorithms.topo import topological_sort +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.feature_builder import FeatureBuilder +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.ray.nodes import ( + RayAggregationNode, + RayDedupNode, + RayDerivedReadNode, + RayFilterNode, + RayJoinNode, + RayReadNode, + RayTransformationNode, + RayWriteNode, +) + +if TYPE_CHECKING: + from feast.infra.compute_engines.ray.config import RayComputeEngineConfig + +logger = logging.getLogger(__name__) + + +class RayFeatureBuilder(FeatureBuilder): + """ + Ray-specific feature builder that constructs execution plans using Ray DAG nodes. + This builder translates FeatureView definitions into Ray-optimized execution DAGs + that can leverage distributed computing for large-scale feature processing. 
+ """ + + def __init__( + self, + registry, + feature_view, + task: Union[MaterializationTask, HistoricalRetrievalTask], + config: "RayComputeEngineConfig", + ): + super().__init__(registry, feature_view, task) + self.config = config + self.is_historical_retrieval = isinstance(task, HistoricalRetrievalTask) + self.is_materialization = isinstance(task, MaterializationTask) + + def build_source_node(self, view): + """Build the source node for reading feature data.""" + data_source = getattr(view, "batch_source", None) or getattr( + view, "source", None + ) + start_time = self.task.start_time + end_time = self.task.end_time + column_info = self.get_column_info(view) + + node = RayReadNode( + name="source", + source=data_source, + column_info=column_info, + config=self.config, + start_time=start_time, + end_time=end_time, + ) + + self.nodes.append(node) + return node + + def build_aggregation_node(self, view, input_node: DAGNode) -> DAGNode: + """Build aggregation node for Ray.""" + agg_specs = getattr(view, "aggregations", []) + if not agg_specs: + raise ValueError(f"No aggregations found for {view.name}") + + group_by_keys = view.entities + timestamp_col = getattr(view.batch_source, "timestamp_field", "event_timestamp") + + node = RayAggregationNode( + name="aggregation", + aggregations=agg_specs, + group_by_keys=group_by_keys, + timestamp_col=timestamp_col, + config=self.config, + ) + node.add_input(input_node) + + self.nodes.append(node) + return node + + def build_join_node(self, view, input_nodes): + """Build the join node for combining multiple feature sources.""" + column_info = self.get_column_info(view) + + node = RayJoinNode( + name="join", + column_info=column_info, + config=self.config, + is_historical_retrieval=self.is_historical_retrieval, + ) + for input_node in input_nodes: + node.add_input(input_node) + + self.nodes.append(node) + return node + + def build_filter_node(self, view, input_node): + """Build the filter node for TTL and custom filtering.""" + ttl = getattr(view, "ttl", None) + filter_condition = getattr(view, "filter", None) + column_info = self.get_column_info(view) + + node = RayFilterNode( + name="filter", + column_info=column_info, + config=self.config, + ttl=ttl, + filter_condition=filter_condition, + ) + node.add_input(input_node) + + self.nodes.append(node) + return node + + def build_dedup_node(self, view, input_node): + """Build the deduplication node for removing duplicate records.""" + column_info = self.get_column_info(view) + + node = RayDedupNode( + name="dedup", + column_info=column_info, + config=self.config, + ) + node.add_input(input_node) + + self.nodes.append(node) + return node + + def build_transformation_node(self, view, input_nodes): + """Build the transformation node for user-defined transformations.""" + feature_transformation = getattr(view, "feature_transformation", None) + udf = getattr(view, "udf", None) + + transformation = feature_transformation or udf + if not transformation: + raise ValueError(f"No feature transformation found for {view.name}") + + node = RayTransformationNode( + name="transformation", + transformation=transformation, + config=self.config, + ) + for input_node in input_nodes: + node.add_input(input_node) + + self.nodes.append(node) + return node + + def build_output_nodes(self, view, final_node): + """Build the output node for writing processed features.""" + node = RayWriteNode( + name="output", + feature_view=view, + inputs=[final_node], + ) + + self.nodes.append(node) + return node + + def 
build_validation_node(self, view, input_node): + """Build the validation node for feature validation.""" + # TODO: Implement validation logic + logger.warning( + "Feature validation is not yet implemented for Ray compute engine." + ) + return input_node + + def _build(self, view, input_nodes: Optional[List[DAGNode]]) -> DAGNode: + has_physical_source = (hasattr(view, "batch_source") and view.batch_source) or ( + hasattr(view, "source") + and view.source + and not isinstance(view.source, FeatureView) + ) + + is_derived_view = hasattr(view, "source_views") and view.source_views + + if has_physical_source and not is_derived_view: + last_node = self.build_source_node(view) + if self._should_transform(view): + last_node = self.build_transformation_node(view, [last_node]) + elif input_nodes: + if self._should_transform(view): + last_node = self.build_transformation_node(view, input_nodes) + else: + last_node = self.build_join_node(view, input_nodes) + else: + raise ValueError(f"FeatureView {view.name} has no valid source or inputs") + + if last_node is None: + raise ValueError(f"Failed to build processing node for {view.name}") + + last_node = self.build_filter_node(view, last_node) + + if self._should_aggregate(view): + last_node = self.build_aggregation_node(view, last_node) + elif self._should_dedupe(view): + last_node = self.build_dedup_node(view, last_node) + + if self._should_validate(view): + last_node = self.build_validation_node(view, last_node) + + return last_node + + def build(self) -> ExecutionPlan: + """Build execution plan with support for derived feature views and sink_source writing.""" + if self.is_historical_retrieval and self._should_aggregate(self.feature_view): + return self._build_aggregation_optimized_plan() + + if self.is_materialization: + return self._build_materialization_plan() + + return super().build() + + def _build_materialization_plan(self) -> ExecutionPlan: + """Build execution plan for materialization with intermediate sink writes.""" + logger.info(f"Building materialization plan for {self.feature_view.name}") + + # Step 1: Topo sort the FeatureViewNode DAG (Logical DAG) + logical_nodes = self.feature_resolver.topological_sort(self.dag_root) + logger.info( + f"Logical nodes in topo order: {[node.view.name for node in logical_nodes]}" + ) + + # Step 2: For each FeatureView, build its corresponding execution DAGNode and write node + # Build them in dependency order to ensure proper execution + view_to_write_node: Dict[str, RayWriteNode] = {} + + for i, logical_node in enumerate(logical_nodes): + view = logical_node.view + logger.info( + f"Building nodes for view {view.name} (step {i + 1}/{len(logical_nodes)})" + ) + + # For derived views, we need to ensure parent views are materialized first + # So we create a processing chain that depends on parent write nodes + parent_write_nodes = [] + processing_node: DAGNode + if hasattr(view, "source_views") and view.source_views: + # This is a derived view - collect parent write nodes as dependencies + for parent in logical_node.inputs: + if parent.view.name in view_to_write_node: + parent_write_nodes.append(view_to_write_node[parent.view.name]) + + if parent_write_nodes: + derived_read_node = RayDerivedReadNode( + name=f"{view.name}:derived_read", + feature_view=view, + parent_dependencies=cast(List[DAGNode], parent_write_nodes), + config=self.config, + column_info=self.get_column_info(view), + is_materialization=self.is_materialization, + ) + self.nodes.append(derived_read_node) + + # Then build the rest of the 
processing chain (filter, aggregate, etc.) + processing_node = self._build(view, [derived_read_node]) + else: + # Parent not yet built - this shouldn't happen in topo order + raise ValueError(f"Parent views for {view.name} not yet built") + else: + # Regular view - build normal processing chain + processing_node = self._build(view, None) + + # Create a write node for this view + write_node = RayWriteNode( + name=f"{view.name}:write", + feature_view=view, + inputs=[processing_node], + ) + + view_to_write_node[view.name] = write_node + logger.info(f"Created write node for {view.name}") + + # Step 3: The final write node is the one for the top-level feature view + final_node = view_to_write_node[self.feature_view.name] + + # Step 4: Topo sort the final DAG from the output node (Physical DAG) + sorted_nodes = topological_sort(final_node) + + # Step 5: Update self.nodes to include all nodes for the execution plan + self.nodes = sorted_nodes + + # Step 6: Return sorted execution plan + return ExecutionPlan(sorted_nodes) + + def _build_aggregation_optimized_plan(self) -> ExecutionPlan: + """Build execution plan optimized for aggregation scenarios.""" + + # 1. Read source data + last_node = self.build_source_node(self.feature_view) + + # 2. Apply filters (TTL, custom filters) BEFORE aggregation + last_node = self.build_filter_node(self.feature_view, last_node) + + # 3. Aggregate across all historical records + last_node = self.build_aggregation_node(self.feature_view, last_node) + + # 4. Join with entity_df to get aggregated features for each entity + last_node = self.build_join_node(self.feature_view, [last_node]) + + # 5. Apply transformations to aggregated features + if self._should_transform(self.feature_view): + last_node = self.build_transformation_node(self.feature_view, [last_node]) + + # 6. Output + last_node = self.build_output_nodes(self.feature_view, last_node) + + return ExecutionPlan(self.nodes) diff --git a/sdk/python/feast/infra/compute_engines/ray/job.py b/sdk/python/feast/infra/compute_engines/ray/job.py new file mode 100644 index 00000000000..b2e88f1d5c5 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/job.py @@ -0,0 +1,297 @@ +import logging +import uuid +from dataclasses import dataclass +from typing import List, Optional + +import pandas as pd +import pyarrow as pa +import ray +from ray.data import Dataset + +from feast import OnDemandFeatureView +from feast.dqm.errors import ValidationFailed +from feast.errors import SavedDatasetLocationAlreadyExists +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, +) +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.offline_stores.file_source import SavedDatasetFileStorage +from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDatasetStorage + +logger = logging.getLogger(__name__) + + +class RayDAGRetrievalJob(RetrievalJob): + """ + Ray-based retrieval job that executes a DAG plan to retrieve historical features. 
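+
+    Illustrative usage sketch (here `job` is assumed to be the RayDAGRetrievalJob
+    returned by `RayComputeEngine.get_historical_features`):
+
+    ```python
+    df = job.to_df()                # pandas DataFrame
+    table = job.to_arrow()          # pyarrow Table
+    ray_ds = job.to_ray_dataset()   # Ray-specific accessor: ray.data.Dataset
+    uris = job.to_remote_storage()  # requires batch_engine.staging_location to be set
+    ```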
+ """ + + def __init__( + self, + plan: Optional[ExecutionPlan], + context: Optional[ExecutionContext], + config: RepoConfig, + full_feature_names: bool, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + feature_refs: Optional[List[str]] = None, + metadata: Optional[RetrievalMetadata] = None, + error: Optional[BaseException] = None, + ): + super().__init__() + self._plan = plan + self._context = context + self._config = config + self._full_feature_names = full_feature_names + self._on_demand_feature_views = on_demand_feature_views or [] + self._feature_refs = feature_refs or [] + self._metadata = metadata + self._error = error + self._result_dataset: Optional[Dataset] = None + self._result_df: Optional[pd.DataFrame] = None + self._result_arrow: Optional[pa.Table] = None + + def error(self) -> Optional[BaseException]: + """Return any error that occurred during job execution.""" + return self._error + + def _ensure_executed(self) -> DAGValue: + """Ensure the execution plan has been executed.""" + if self._result_dataset is None and self._plan and self._context: + try: + result = self._plan.execute(self._context) + if hasattr(result, "data") and isinstance(result.data, Dataset): + self._result_dataset = result.data + else: + # If result is not a Ray Dataset, convert it + if isinstance(result.data, pd.DataFrame): + self._result_dataset = ray.data.from_pandas(result.data) + elif isinstance(result.data, pa.Table): + self._result_dataset = ray.data.from_arrow(result.data) + else: + raise ValueError( + f"Unsupported result type: {type(result.data)}" + ) + return result + except Exception as e: + self._error = e + logger.error(f"Ray DAG execution failed: {e}") + raise + elif self._result_dataset is None: + raise ValueError("No execution plan available or execution failed") + + # Return a mock DAGValue for compatibility + return DAGValue(data=self._result_dataset, format=DAGFormat.RAY) + + def to_ray_dataset(self) -> Dataset: + """Get the result as a Ray Dataset.""" + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + return self._result_dataset + + def to_df( + self, + validation_reference=None, + timeout: Optional[int] = None, + ) -> pd.DataFrame: + """Convert the result to a pandas DataFrame.""" + if self._result_df is None: + if self.on_demand_feature_views: + # Use parent implementation for ODFV processing + logger.info( + f"Processing {len(self.on_demand_feature_views)} on-demand feature views" + ) + self._result_df = super().to_df( + validation_reference=validation_reference, timeout=timeout + ) + else: + # Direct conversion from Ray Dataset + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + self._result_df = self._result_dataset.to_pandas() + + # Handle validation if provided + if validation_reference: + try: + validation_result = validation_reference.profile.validate( + self._result_df + ) + if not validation_result.is_success: + raise ValidationFailed(validation_result) + except ImportError: + logger.warning("DQM profiler not available, skipping validation") + except Exception as e: + logger.error(f"Validation failed: {e}") + raise ValueError(f"Data validation failed: {e}") + + return self._result_df + + def to_arrow( + self, + validation_reference=None, + timeout: Optional[int] = None, + ) -> pa.Table: + """Convert the result to an Arrow Table.""" + if self._result_arrow is None: + if self.on_demand_feature_views: + # Use 
parent implementation for ODFV processing + self._result_arrow = super().to_arrow( + validation_reference=validation_reference, timeout=timeout + ) + else: + # Direct conversion from Ray Dataset + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + self._result_arrow = self._result_dataset.to_pandas().to_arrow() + + # Handle validation if provided + if validation_reference: + try: + df = self._result_arrow.to_pandas() + validation_result = validation_reference.profile.validate(df) + if not validation_result.is_success: + raise ValidationFailed(validation_result) + except ImportError: + logger.warning("DQM profiler not available, skipping validation") + except Exception as e: + logger.error(f"Validation failed: {e}") + raise ValueError(f"Data validation failed: {e}") + + return self._result_arrow + + def to_remote_storage(self) -> list[str]: + """Write the result to remote storage.""" + if not self._config.batch_engine.staging_location: + raise ValueError("Staging location must be set for remote storage") + + try: + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + output_uri = ( + f"{self._config.batch_engine.staging_location}/{str(uuid.uuid4())}" + ) + self._result_dataset.write_parquet(output_uri) + logger.debug(f"Wrote result to {output_uri}") + return [output_uri] + except Exception as e: + raise RuntimeError(f"Failed to write to remote storage: {e}") + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: bool = False, + timeout: Optional[int] = None, + ) -> str: + """Persist the result to the specified storage.""" + if not isinstance(storage, SavedDatasetFileStorage): + raise ValueError( + f"Ray compute engine only supports SavedDatasetFileStorage, got {type(storage)}" + ) + + destination_path = storage.file_options.uri + + # Check if destination already exists + if not destination_path.startswith(("s3://", "gs://", "hdfs://")): + import os + + if not allow_overwrite and os.path.exists(destination_path): + raise SavedDatasetLocationAlreadyExists(location=destination_path) + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + + try: + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + self._result_dataset.write_parquet(destination_path) + return destination_path + except Exception as e: + raise RuntimeError(f"Failed to persist dataset to {destination_path}: {e}") + + def to_sql(self) -> str: + """Generate SQL representation of the execution plan.""" + if self._plan and self._context: + return self._plan.to_sql(self._context) + raise NotImplementedError("SQL generation not available without execution plan") + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return self._on_demand_feature_views + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + """Internal method to get DataFrame (used by parent class).""" + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + return self._result_dataset.to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: + """Internal method to get Arrow Table (used by parent class).""" + 
self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + return self._result_dataset.to_pandas().to_arrow() + + +@dataclass +class RayMaterializationJob(MaterializationJob): + """ + Ray-based materialization job that tracks the status of feature materialization. + """ + + def __init__( + self, + job_id: str, + status: MaterializationJobStatus, + result: Optional[DAGValue] = None, + error: Optional[BaseException] = None, + ): + super().__init__() + self._job_id = job_id + self._status = status + self._result = result + self._error = error + + def job_id(self) -> str: + return self._job_id + + def status(self) -> MaterializationJobStatus: + return self._status + + def error(self) -> Optional[BaseException]: + return self._error + + def should_be_retried(self) -> bool: + """Ray jobs are generally not retried by default.""" + return False + + def url(self) -> Optional[str]: + """Ray jobs don't have a specific URL.""" + return None + + def result(self) -> Optional[DAGValue]: + """Get the result of the materialization job.""" + return self._result diff --git a/sdk/python/feast/infra/compute_engines/ray/nodes.py b/sdk/python/feast/infra/compute_engines/ray/nodes.py new file mode 100644 index 00000000000..5a5f04acee3 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/nodes.py @@ -0,0 +1,706 @@ +import logging +from datetime import datetime, timedelta, timezone +from typing import List, Optional, Union + +import dill +import pandas as pd +import pyarrow as pa +import ray +from ray.data import Dataset + +from feast import BatchFeatureView, FeatureView, StreamFeatureView +from feast.aggregation import Aggregation +from feast.data_source import DataSource +from feast.feature_view_utils import get_transformation_function, has_transformation +from feast.infra.common.serde import SerializedArtifacts +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.utils import create_offline_store_retrieval_job +from feast.infra.ray_shared_utils import ( + apply_field_mapping, + broadcast_join, + distributed_windowed_join, +) + +logger = logging.getLogger(__name__) + +# Entity timestamp alias for historical feature retrieval +ENTITY_TS_ALIAS = "__entity_event_timestamp" + + +class RayReadNode(DAGNode): + """ + Ray node for reading data from offline stores. 
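+
+    Illustrative construction sketch (the `batch_source`, `column_info`,
+    `engine_config` and `task` values are assumed to come from
+    RayFeatureBuilder.build_source_node):
+
+    ```python
+    read_node = RayReadNode(
+        name="source",
+        source=batch_source,       # a Feast DataSource
+        column_info=column_info,
+        config=engine_config,      # RayComputeEngineConfig
+        start_time=task.start_time,
+        end_time=task.end_time,
+    )
+    # execute() returns a DAGValue wrapping a ray.data.Dataset in DAGFormat.RAY.
+    ```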
+ """ + + def __init__( + self, + name: str, + source: DataSource, + column_info, + config: RayComputeEngineConfig, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ): + super().__init__(name) + self.source = source + self.column_info = column_info + self.config = config + self.start_time = start_time + self.end_time = end_time + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the read operation to load data from the offline store.""" + try: + retrieval_job = create_offline_store_retrieval_job( + data_source=self.source, + column_info=self.column_info, + context=context, + start_time=self.start_time, + end_time=self.end_time, + ) + + if hasattr(retrieval_job, "to_ray_dataset"): + ray_dataset = retrieval_job.to_ray_dataset() + else: + try: + arrow_table = retrieval_job.to_arrow() + ray_dataset = ray.data.from_arrow(arrow_table) + except Exception: + df = retrieval_job.to_df() + ray_dataset = ray.data.from_pandas(df) + + field_mapping = getattr(self.source, "field_mapping", None) + if field_mapping: + ray_dataset = apply_field_mapping(ray_dataset, field_mapping) + + return DAGValue( + data=ray_dataset, + format=DAGFormat.RAY, + metadata={ + "source": "offline_store", + "source_type": type(self.source).__name__, + "start_time": self.start_time, + "end_time": self.end_time, + }, + ) + + except Exception as e: + logger.error(f"Ray read node failed: {e}") + raise + + +class RayJoinNode(DAGNode): + """ + Ray node for joining entity dataframes with feature data. + """ + + def __init__( + self, + name: str, + column_info, + config: RayComputeEngineConfig, + is_historical_retrieval: bool = False, + ): + super().__init__(name) + self.column_info = column_info + self.config = config + self.is_historical_retrieval = is_historical_retrieval + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the join operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + feature_dataset: Dataset = input_value.data + + # If this is not a historical retrieval, just return the feature data + if not self.is_historical_retrieval or context.entity_df is None: + return DAGValue( + data=feature_dataset, + format=DAGFormat.RAY, + metadata={"joined": False}, + ) + + entity_df = context.entity_df + if isinstance(entity_df, pd.DataFrame): + entity_dataset = ray.data.from_pandas(entity_df) + else: + entity_dataset = entity_df + + join_keys = self.column_info.join_keys + timestamp_col = self.column_info.timestamp_column + requested_feats = getattr(self.column_info, "feature_cols", []) + + # Check if the feature dataset contains aggregated features (from aggregation node) + # If so, we don't need point-in-time join logic - just simple join on entity keys + is_aggregated = ( + input_value.metadata.get("aggregated", False) + if input_value.metadata + else False + ) + + feature_size = feature_dataset.size_bytes() + + if is_aggregated: + # For aggregated features, do simple join on entity keys + feature_df = feature_dataset.to_pandas() + feature_ref = ray.put(feature_df) + + def join_with_aggregated_features(batch: pd.DataFrame) -> pd.DataFrame: + if batch.empty: + return batch + features = ray.get(feature_ref) + if join_keys: + result = pd.merge( + batch, + features, + on=join_keys, + how="left", + suffixes=("", "_feature"), + ) + else: + result = batch.copy() + return result + + joined_dataset = entity_dataset.map_batches( + join_with_aggregated_features, batch_format="pandas" + ) + else: + if 
feature_size <= self.config.broadcast_join_threshold_mb * 1024 * 1024: + # Use broadcast join for small feature datasets + joined_dataset = broadcast_join( + entity_dataset, + feature_dataset.to_pandas(), + join_keys, + timestamp_col, + requested_feats, + ) + else: + # Use distributed join for large datasets + joined_dataset = distributed_windowed_join( + entity_dataset, + feature_dataset, + join_keys, + timestamp_col, + requested_feats, + ) + + return DAGValue( + data=joined_dataset, + format=DAGFormat.RAY, + metadata={ + "joined": True, + "join_keys": join_keys, + "join_strategy": "broadcast" + if feature_size <= self.config.broadcast_join_threshold_mb * 1024 * 1024 + else "distributed", + }, + ) + + +class RayFilterNode(DAGNode): + """ + Ray node for filtering data based on TTL and custom conditions. + """ + + def __init__( + self, + name: str, + column_info, + config: RayComputeEngineConfig, + ttl: Optional[timedelta] = None, + filter_condition: Optional[str] = None, + ): + super().__init__(name) + self.column_info = column_info + self.config = config + self.ttl = ttl + self.filter_condition = filter_condition + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the filter operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + def apply_filters(batch: pd.DataFrame) -> pd.DataFrame: + """Apply TTL and custom filters to the batch.""" + if batch.empty: + return batch + + filtered_batch = batch.copy() + + # Apply TTL filter if specified + if self.ttl: + timestamp_col = self.column_info.timestamp_column + if timestamp_col in filtered_batch.columns: + # Convert to datetime if not already + if not pd.api.types.is_datetime64_any_dtype( + filtered_batch[timestamp_col] + ): + filtered_batch[timestamp_col] = pd.to_datetime( + filtered_batch[timestamp_col] + ) + + # For historical retrieval, use entity timestamp for TTL calculation + if ENTITY_TS_ALIAS in filtered_batch.columns: + # Use entity timestamp for TTL calculation (historical retrieval) + if not pd.api.types.is_datetime64_any_dtype( + filtered_batch[ENTITY_TS_ALIAS] + ): + filtered_batch[ENTITY_TS_ALIAS] = pd.to_datetime( + filtered_batch[ENTITY_TS_ALIAS] + ) + + # Apply TTL filter with both upper and lower bounds: + # 1. feature.ts <= entity.event_timestamp (upper bound) + # 2. 
feature.ts >= entity.event_timestamp - ttl (lower bound) + upper_bound = filtered_batch[ENTITY_TS_ALIAS] + lower_bound = filtered_batch[ENTITY_TS_ALIAS] - self.ttl + + filtered_batch = filtered_batch[ + (filtered_batch[timestamp_col] <= upper_bound) + & (filtered_batch[timestamp_col] >= lower_bound) + ] + else: + # Use current time for TTL calculation (real-time retrieval) + # Check if timestamp column is timezone-aware + if pd.api.types.is_datetime64tz_dtype( + filtered_batch[timestamp_col] + ): + # Use timezone-aware current time + current_time = datetime.now(timezone.utc) + else: + # Use naive datetime + current_time = datetime.now() + + ttl_threshold = current_time - self.ttl + + # Apply TTL filter + filtered_batch = filtered_batch[ + filtered_batch[timestamp_col] >= ttl_threshold + ] + + # Apply custom filter condition if specified + if self.filter_condition: + try: + filtered_batch = filtered_batch.query(self.filter_condition) + except Exception as e: + logger.warning(f"Custom filter failed: {e}") + + return filtered_batch + + filtered_dataset = dataset.map_batches(apply_filters, batch_format="pandas") + + return DAGValue( + data=filtered_dataset, + format=DAGFormat.RAY, + metadata={ + "filtered": True, + "ttl": self.ttl, + "filter_condition": self.filter_condition, + }, + ) + + +class RayAggregationNode(DAGNode): + """ + Ray node for performing aggregations on feature data. + """ + + def __init__( + self, + name: str, + aggregations: List[Aggregation], + group_by_keys: List[str], + timestamp_col: str, + config: RayComputeEngineConfig, + ): + super().__init__(name) + self.aggregations = aggregations + self.group_by_keys = group_by_keys + self.timestamp_col = timestamp_col + self.config = config + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the aggregation operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + # Convert aggregations to Ray's groupby format + agg_dict = {} + for agg in self.aggregations: + feature_name = f"{agg.function}_{agg.column}" + if agg.time_window: + feature_name += f"_{int(agg.time_window.total_seconds())}s" + + if agg.function == "count": + agg_dict[feature_name] = (agg.column, "count") + elif agg.function == "sum": + agg_dict[feature_name] = (agg.column, "sum") + elif agg.function == "mean" or agg.function == "avg": + agg_dict[feature_name] = (agg.column, "mean") + elif agg.function == "min": + agg_dict[feature_name] = (agg.column, "min") + elif agg.function == "max": + agg_dict[feature_name] = (agg.column, "max") + elif agg.function == "std": + agg_dict[feature_name] = (agg.column, "std") + elif agg.function == "var": + agg_dict[feature_name] = (agg.column, "var") + else: + raise ValueError(f"Unknown aggregation function: {agg.function}.") + + # Apply aggregations using pandas fallback (Ray's native groupby has compatibility issues) + if self.group_by_keys and agg_dict: + # Use pandas-based aggregation for entire dataset + aggregated_dataset = self._fallback_pandas_aggregation(dataset, agg_dict) + else: + # No group keys or aggregations, return original dataset + aggregated_dataset = dataset + + return DAGValue( + data=aggregated_dataset, + format=DAGFormat.RAY, + metadata={ + "aggregated": True, + "aggregations": len(self.aggregations), + "group_by_keys": self.group_by_keys, + }, + ) + + def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Dataset: + """Fallback to pandas-based aggregation for the entire 
dataset.""" + # Convert entire dataset to pandas for aggregation + df = dataset.to_pandas() + + if df.empty: + return dataset + + # Group by the specified keys + if self.group_by_keys: + grouped = df.groupby(self.group_by_keys) + else: + # If no group keys, apply aggregations to entire dataset + grouped = df.groupby(lambda x: 0) # Dummy grouping + + # Apply each aggregation + agg_results = [] + for feature_name, (column, function) in agg_dict.items(): + if column in df.columns: + if function == "count": + result = grouped[column].count() + elif function == "sum": + result = grouped[column].sum() + elif function == "mean": + result = grouped[column].mean() + elif function == "min": + result = grouped[column].min() + elif function == "max": + result = grouped[column].max() + elif function == "std": + result = grouped[column].std() + elif function == "var": + result = grouped[column].var() + else: + raise ValueError(f"Unknown aggregation function: {function}.") + + result.name = feature_name + agg_results.append(result) + + # Combine aggregation results + if agg_results: + result_df = pd.concat(agg_results, axis=1) + + # Reset index to make group keys regular columns + if self.group_by_keys: + result_df = result_df.reset_index() + + # Convert back to Ray Dataset + return ray.data.from_pandas(result_df) + else: + return dataset + + +class RayDedupNode(DAGNode): + """ + Ray node for deduplicating records. + """ + + def __init__( + self, + name: str, + column_info, + config: RayComputeEngineConfig, + ): + super().__init__(name) + self.column_info = column_info + self.config = config + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the deduplication operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame: + """Remove duplicates from the batch.""" + if batch.empty: + return batch + + # Get deduplication keys + join_keys = self.column_info.join_keys + timestamp_col = self.column_info.timestamp_column + + if join_keys: + # Sort by join keys and timestamp (most recent first) + sort_columns = join_keys + [timestamp_col] + available_columns = [ + col for col in sort_columns if col in batch.columns + ] + + if available_columns: + # Sort and deduplicate + sorted_batch = batch.sort_values( + available_columns, + ascending=[True] * len(join_keys) + + [False], # Recent timestamps first + ) + + # Keep first occurrence (most recent) for each join key combination + deduped_batch = sorted_batch.drop_duplicates( + subset=join_keys, + keep="first", + ) + + return deduped_batch + + return batch + + deduped_dataset = dataset.map_batches(deduplicate_batch, batch_format="pandas") + + return DAGValue( + data=deduped_dataset, + format=DAGFormat.RAY, + metadata={"deduped": True}, + ) + + +class RayTransformationNode(DAGNode): + """ + Ray node for applying feature transformations. 
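+
+    Illustrative sketch of the kind of pandas-batch UDF this node serializes with
+    dill and applies via Dataset.map_batches(batch_format="pandas") (the column
+    names are hypothetical):
+
+    ```python
+    import pandas as pd
+
+    def add_conv_rate_pct(batch: pd.DataFrame) -> pd.DataFrame:
+        # Called once per pandas batch; returns the transformed batch.
+        batch["conv_rate_pct"] = batch["conv_rate"] * 100.0
+        return batch
+    ```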
+ """ + + def __init__( + self, + name: str, + transformation, + config: RayComputeEngineConfig, + ): + super().__init__(name) + self.transformation = transformation + self.transformation_name = getattr(transformation, "name", "unknown") + self.config = config + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the transformation operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + transformation_serialized = None + if hasattr(self.transformation, "udf") and callable(self.transformation.udf): + transformation_serialized = dill.dumps(self.transformation.udf) + elif callable(self.transformation): + transformation_serialized = dill.dumps(self.transformation) + + def apply_transformation_with_serialized_udf( + batch: pd.DataFrame, + ) -> pd.DataFrame: + """Apply the transformation using pre-serialized UDF.""" + if batch.empty: + return batch + + try: + if transformation_serialized: + transformation_func = dill.loads(transformation_serialized) + transformed_batch = transformation_func(batch) + else: + logger.warning( + "No serialized transformation available, returning original batch" + ) + transformed_batch = batch + + return transformed_batch + except Exception as e: + logger.error(f"Transformation failed: {e}") + return batch + + transformed_dataset = dataset.map_batches( + apply_transformation_with_serialized_udf, batch_format="pandas" + ) + + return DAGValue( + data=transformed_dataset, + format=DAGFormat.RAY, + metadata={ + "transformed": True, + "transformation": self.transformation_name, + }, + ) + + +class RayDerivedReadNode(DAGNode): + """ + Ray node for reading derived feature views after parent dependencies are materialized. + This node ensures that parent feature views are fully materialized before reading from their sink_source. 
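+
+    Illustrative wiring sketch (`parent_write_node`, `column_info` and
+    `engine_config` are assumed to be built by
+    RayFeatureBuilder._build_materialization_plan):
+
+    ```python
+    derived_read = RayDerivedReadNode(
+        name=f"{view.name}:derived_read",
+        feature_view=view,
+        parent_dependencies=[parent_write_node],  # RayWriteNode of the parent view
+        config=engine_config,
+        column_info=column_info,
+        is_materialization=True,
+    )
+    ```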
+ """ + + def __init__( + self, + name: str, + feature_view: FeatureView, + parent_dependencies: List[DAGNode], + config: RayComputeEngineConfig, + column_info, + is_materialization: bool = True, + ): + super().__init__(name) + self.feature_view = feature_view + self.config = config + self.column_info = column_info + self.is_materialization = is_materialization + + # Add parent dependencies to ensure they execute first + for parent in parent_dependencies: + self.add_input(parent) + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the derived read operation after parents are materialized.""" + parent_values = self.get_input_values(context) + + if not parent_values: + raise ValueError( + f"No parent data available for derived view {self.feature_view.name}" + ) + + parent_value = parent_values[0] + parent_value.assert_format(DAGFormat.RAY) + + if has_transformation(self.feature_view): + transformation_func = get_transformation_function(self.feature_view) + if callable(transformation_func): + + def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame: + return transformation_func(batch) + + transformed_dataset = parent_value.data.map_batches( + apply_transformation + ) + return DAGValue( + data=transformed_dataset, + format=DAGFormat.RAY, + metadata={ + "source": "derived_from_parent", + "source_description": f"Transformed data from parent for {self.feature_view.name}", + }, + ) + + return DAGValue( + data=parent_value.data, + format=DAGFormat.RAY, + metadata={ + "source": "derived_from_parent", + "source_description": f"Data from parent for {self.feature_view.name}", + }, + ) + + +class RayWriteNode(DAGNode): + """ + Ray node for writing results to online/offline stores and sink_source paths. + This node handles writing intermediate results for derived feature views. 
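+
+    Illustrative wiring sketch (`processing_node` is assumed to be the last
+    processing DAGNode built for the view, as in
+    RayFeatureBuilder._build_materialization_plan):
+
+    ```python
+    write_node = RayWriteNode(
+        name=f"{view.name}:write",
+        feature_view=view,
+        inputs=[processing_node],
+    )
+    ```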
+ """ + + def __init__( + self, + name: str, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + inputs=None, + ): + super().__init__(name, inputs=inputs) + self.feature_view = feature_view + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the write operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + serialized_artifacts = SerializedArtifacts.serialize( + feature_view=self.feature_view, repo_config=context.repo_config + ) + + def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame: + """Write each batch using pre-serialized artifacts.""" + if batch.empty: + return batch + + try: + ( + feature_view, + online_store, + offline_store, + repo_config, + ) = serialized_artifacts.unserialize() + + arrow_table = pa.Table.from_pandas(batch) + + # Write to online store if enabled + if getattr(feature_view, "online", False): + # TODO: Implement proper online store writing with correct data format conversion + logger.debug( + "Online store writing not implemented yet for Ray compute engine" + ) + + # Write to offline store if enabled + if getattr(feature_view, "offline", False): + try: + offline_store.offline_write_batch( + config=repo_config, + feature_view=feature_view, + table=arrow_table, + progress=lambda x: None, + ) + except Exception as e: + logger.error(f"Failed to write to offline store: {e}") + raise + + return batch + + except Exception as e: + logger.error(f"Write operation failed: {e}") + raise + + written_dataset = dataset.map_batches( + write_batch_with_serialized_artifacts, batch_format="pandas" + ) + written_dataset = written_dataset.materialize() + + return DAGValue( + data=written_dataset, + format=DAGFormat.RAY, + metadata={ + "written": True, + "feature_view": self.feature_view.name, + "online": getattr(self.feature_view, "online", False), + "offline": getattr(self.feature_view, "offline", False), + "batch_source_path": getattr( + getattr(self.feature_view, "batch_source", None), "path", "unknown" + ), + }, + ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py new file mode 100644 index 00000000000..d0eb96bfcb2 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py @@ -0,0 +1,56 @@ +""" +Ray offline store for Feast. + +This module provides distributed offline feature store functionality using Ray with +advanced optimization features for scalable feature retrieval. + +Key Features: +- Intelligent join strategy selection (broadcast vs. 
distributed) +- Resource-aware partitioning and parallelism +- Windowed temporal joins for large datasets +- Configurable performance tuning parameters +- Automatic cluster resource management + +Classes: +- RayOfflineStore: Main offline store implementation +- RayOfflineStoreConfig: Configuration with optimization settings +- RayRetrievalJob: Enhanced retrieval job with caching +- RayResourceManager: Cluster resource management +- RayDataProcessor: Optimized data processing operations + +Usage: +Configure in your feature_store.yaml: +```yaml +offline_store: + type: ray + storage_path: /path/to/storage + broadcast_join_threshold_mb: 100 + enable_distributed_joins: true + max_parallelism_multiplier: 2 + target_partition_size_mb: 64 + window_size_for_joins: "1H" +``` + +Performance Optimizations: +- Broadcast joins for small datasets (<100MB by default) +- Distributed windowed joins for large datasets +- Optimal partitioning based on cluster resources +- Memory-aware buffer sizing +- Lazy evaluation with caching +""" + +from .ray import ( + RayDataProcessor, + RayOfflineStore, + RayOfflineStoreConfig, + RayResourceManager, + RayRetrievalJob, +) + +__all__ = [ + "RayOfflineStore", + "RayOfflineStoreConfig", + "RayRetrievalJob", + "RayResourceManager", + "RayDataProcessor", +] diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py new file mode 100644 index 00000000000..8a82ec24a64 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -0,0 +1,2107 @@ +import logging +import os +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import dill +import fsspec +import numpy as np +import pandas as pd +import pyarrow as pa +import ray +import ray.data +from ray.data import Dataset +from ray.data.context import DatasetContext + +from feast.data_source import DataSource +from feast.errors import ( + RequestDataNotFoundInEntityDfException, + SavedDatasetLocationAlreadyExists, +) +from feast.feature_logging import LoggingConfig, LoggingSource +from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView +from feast.feature_view_utils import resolve_feature_view_source_with_fallback +from feast.infra.offline_stores.file_source import ( + FileLoggingDestination, + FileSource, + SavedDatasetFileStorage, +) +from feast.infra.offline_stores.offline_store import ( + OfflineStore, + RetrievalJob, + RetrievalMetadata, +) +from feast.infra.offline_stores.offline_utils import ( + get_entity_df_timestamp_bounds, + get_pyarrow_schema_from_batch_source, + infer_event_timestamp_from_entity_df, +) +from feast.infra.ray_shared_utils import ( + _build_required_columns, + apply_field_mapping, + ensure_timestamp_compatibility, + normalize_timestamp_columns, +) +from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.saved_dataset import SavedDatasetStorage, ValidationReference +from feast.type_map import ( + convert_array_column, + convert_scalar_column, + feast_value_type_to_pandas_type, + pa_to_feast_value_type, +) +from feast.utils import _get_column_names, make_df_tzaware, make_tzaware + +logger = logging.getLogger(__name__) + + +def _get_data_schema_info( + data: Union[pd.DataFrame, Dataset], +) -> Tuple[Dict[str, Any], 
List[str]]: + """ + Extract schema information from DataFrame or Dataset. + Args: + data: DataFrame or Ray Dataset + Returns: + Tuple of (dtypes_dict, column_names) + """ + if isinstance(data, Dataset): + schema = data.schema() + dtypes = {} + for i, col in enumerate(schema.names): + field_type = schema.field(i).type + try: + pa_type_str = str(field_type).lower() + feast_value_type = pa_to_feast_value_type(pa_type_str) + pandas_type_str = feast_value_type_to_pandas_type(feast_value_type) + dtypes[col] = pd.api.types.pandas_dtype(pandas_type_str) + except Exception: + dtypes[col] = pd.api.types.pandas_dtype("object") + columns = schema.names + else: + dtypes = data.dtypes.to_dict() + columns = list(data.columns) + return dtypes, columns + + +def _apply_to_data( + data: Union[pd.DataFrame, Dataset], + process_func: Callable[[pd.DataFrame], pd.DataFrame], + inplace: bool = False, +) -> Union[pd.DataFrame, Dataset]: + """ + Apply a processing function to DataFrame or Dataset. + Args: + data: DataFrame or Ray Dataset to process + process_func: Function that takes a DataFrame and returns a processed DataFrame + inplace: Whether to modify DataFrame in place (only applies to pandas) + Returns: + Processed DataFrame or Dataset + """ + if isinstance(data, Dataset): + return data.map_batches(process_func, batch_format="pandas") + else: + if not inplace: + data = data.copy() + return process_func(data) + + +def _handle_empty_dataframe_case( + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_columns: List[str], +) -> pd.DataFrame: + """ + Handle empty DataFrame case by creating properly structured empty DataFrame. + Args: + join_key_columns: List of join key columns + feature_name_columns: List of feature columns + timestamp_columns: List of timestamp columns + Returns: + Empty DataFrame with proper structure and column types + """ + empty_columns = _build_required_columns( + join_key_columns, feature_name_columns, timestamp_columns + ) + df = pd.DataFrame(columns=empty_columns) + for col in timestamp_columns: + if col in df.columns: + df[col] = df[col].astype("datetime64[ns, UTC]") + return df + + +def _safe_infer_event_timestamp_column( + data: Union[pd.DataFrame, Dataset], fallback_column: str = "event_timestamp" +) -> str: + """ + Safely infer the event timestamp column. + Works with both pandas DataFrames and Ray Datasets. + Args: + data: DataFrame or Ray Dataset to analyze + fallback_column: Default column name to use if inference fails + Returns: + Inferred or fallback timestamp column name + """ + try: + dtypes, _ = _get_data_schema_info(data) + return infer_event_timestamp_from_entity_df(dtypes) + except Exception as e: + logger.debug( + f"Timestamp column inference failed: {e}, using fallback: {fallback_column}" + ) + return fallback_column + + +def _safe_get_entity_timestamp_bounds( + data: Union[pd.DataFrame, Dataset], timestamp_column: str +) -> Tuple[Optional[datetime], Optional[datetime]]: + """ + Safely get entity timestamp bounds. + Works with both pandas DataFrames and Ray Datasets. 
+ Args: + data: DataFrame or Ray Dataset + timestamp_column: Name of timestamp column + Returns: + Tuple of (min_timestamp, max_timestamp) or (None, None) if failed + """ + try: + if isinstance(data, Dataset): + min_ts = data.min(timestamp_column) + max_ts = data.max(timestamp_column) + else: + if timestamp_column in data.columns: + min_ts, max_ts = get_entity_df_timestamp_bounds(data, timestamp_column) + else: + return None, None + if hasattr(min_ts, "to_pydatetime"): + min_ts = min_ts.to_pydatetime() + elif isinstance(min_ts, pd.Timestamp): + min_ts = min_ts.to_pydatetime() + if hasattr(max_ts, "to_pydatetime"): + max_ts = max_ts.to_pydatetime() + elif isinstance(max_ts, pd.Timestamp): + max_ts = max_ts.to_pydatetime() + return min_ts, max_ts + except Exception as e: + logger.debug( + f"Timestamp bounds extraction failed: {e}, falling back to manual calculation" + ) + try: + if isinstance(data, Dataset): + + def extract_bounds(batch: pd.DataFrame) -> pd.DataFrame: + if timestamp_column in batch.columns and not batch.empty: + timestamps = pd.to_datetime(batch[timestamp_column], utc=True) + return pd.DataFrame( + {"min_ts": [timestamps.min()], "max_ts": [timestamps.max()]} + ) + return pd.DataFrame({"min_ts": [None], "max_ts": [None]}) + + bounds_ds = data.map_batches(extract_bounds, batch_format="pandas") + bounds_df = bounds_ds.to_pandas() + + if not bounds_df.empty: + min_ts = bounds_df["min_ts"].min() + max_ts = bounds_df["max_ts"].max() + + if pd.notna(min_ts) and pd.notna(max_ts): + return min_ts.to_pydatetime(), max_ts.to_pydatetime() + else: + if timestamp_column in data.columns: + timestamps = pd.to_datetime(data[timestamp_column], utc=True) + return ( + timestamps.min().to_pydatetime(), + timestamps.max().to_pydatetime(), + ) + except Exception: + pass + + return None, None + + +def _safe_validate_schema( + config: RepoConfig, + data_source: DataSource, + table_columns: List[str], + operation_name: str = "operation", +) -> Optional[Tuple[pa.Schema, List[str]]]: + """ + Safely validate schema using offline_utils with graceful fallback. + Args: + config: Repo configuration + data_source: Data source to validate against + table_columns: Actual table column names + operation_name: Name of operation for logging + Returns: + Tuple of (expected_schema, expected_columns) or None if validation fails + """ + try: + expected_schema, expected_columns = get_pyarrow_schema_from_batch_source( + config, data_source + ) + if set(expected_columns) != set(table_columns): + logger.warning( + f"Schema mismatch in {operation_name}:\n" + f" Expected columns: {expected_columns}\n" + f" Actual columns: {table_columns}" + ) + if set(expected_columns) == set(table_columns): + logger.info(f"Columns match but order differs for {operation_name}") + return expected_schema, expected_columns + else: + logger.debug(f"Schema validation passed for {operation_name}") + return expected_schema, expected_columns + + except Exception as e: + logger.warning( + f"Schema validation skipped for {operation_name} due to error: {e}" + ) + logger.debug("Schema validation error details:", exc_info=True) + return None + + +def _convert_feature_column_types( + data: Union[pd.DataFrame, Dataset], feature_views: List[FeatureView] +) -> Union[pd.DataFrame, Dataset]: + """ + Convert feature columns to appropriate pandas types using Feast's type mapping utilities. + Works with both pandas DataFrames and Ray Datasets. 
+ Args: + data: DataFrame or Ray Dataset containing feature data + feature_views: List of feature views with type information + Returns: + DataFrame or Dataset with properly converted feature column types + """ + + def convert_batch(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + + for fv in feature_views: + for feature in fv.features: + feat_name = feature.name + if feat_name not in batch.columns: + continue + try: + value_type = feature.dtype.to_value_type() + if value_type.name.endswith("_LIST"): + batch[feat_name] = convert_array_column( + batch[feat_name], value_type + ) + else: + target_pandas_type = feast_value_type_to_pandas_type(value_type) + batch[feat_name] = convert_scalar_column( + batch[feat_name], value_type, target_pandas_type + ) + except Exception as e: + logger.warning( + f"Failed to convert feature {feat_name} to proper type: {e}" + ) + continue + return batch + + return _apply_to_data(data, convert_batch) + + +class RayOfflineStoreConfig(FeastConfigBaseModel): + """ + Configuration for the Ray Offline Store. + + For detailed configuration options and examples, see the documentation: + https://docs.feast.dev/reference/offline-stores/ray + """ + + type: Literal[ + "feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore", "ray" + ] = "ray" + storage_path: Optional[str] = None + ray_address: Optional[str] = None + + # Optimization settings + broadcast_join_threshold_mb: Optional[int] = 100 + enable_distributed_joins: Optional[bool] = True + max_parallelism_multiplier: Optional[int] = 2 + target_partition_size_mb: Optional[int] = 64 + window_size_for_joins: Optional[str] = "1H" + + # Logging settings + enable_ray_logging: Optional[bool] = False + + # Ray configuration for resource management (memory, CPU limits) + ray_conf: Optional[Dict[str, Any]] = None + + +class RayResourceManager: + """ + Manages Ray cluster resources for optimal performance. + # See: https://docs.feast.dev/reference/offline-stores/ray#resource-management-and-testing + """ + + def __init__(self, config: Optional[RayOfflineStoreConfig] = None) -> None: + """ + Initialize the resource manager with cluster resource information. + """ + self.config = config or RayOfflineStoreConfig() + self.cluster_resources = ray.cluster_resources() + self.available_memory = self.cluster_resources.get("memory", 8 * 1024**3) + self.available_cpus = int(self.cluster_resources.get("CPU", 4)) + self.num_nodes = len(ray.nodes()) if ray.is_initialized() else 1 + + def configure_ray_context(self) -> None: + """ + Configure Ray DatasetContext for optimal performance based on available resources. 
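+
+        Illustrative usage sketch (assumes a Ray runtime is already initialized and
+        `store_config` is a RayOfflineStoreConfig):
+
+        ```python
+        manager = RayResourceManager(store_config)
+        manager.configure_ray_context()
+        num_parts = manager.estimate_optimal_partitions(2 * 1024**3)     # ~2 GB dataset
+        use_broadcast = manager.should_use_broadcast_join(50 * 1024**2)  # ~50 MB dataset
+        ```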
+ """ + ctx = DatasetContext.get_current() + + if self.available_memory > 32 * 1024**3: + ctx.target_shuffle_buffer_size = 2 * 1024**3 + ctx.target_max_block_size = 512 * 1024**2 + else: + ctx.target_shuffle_buffer_size = 512 * 1024**2 + ctx.target_max_block_size = 128 * 1024**2 + ctx.min_parallelism = self.available_cpus + multiplier = ( + self.config.max_parallelism_multiplier + if self.config.max_parallelism_multiplier is not None + else 2 + ) + ctx.max_parallelism = self.available_cpus * multiplier + ctx.shuffle_strategy = "sort" # type: ignore + ctx.enable_tensor_extension_casting = False + + if not getattr(self.config, "enable_ray_logging", False): + ctx.enable_progress_bars = False + if hasattr(ctx, "verbose_progress"): + ctx.verbose_progress = False + + if getattr(self.config, "enable_ray_logging", False): + logger.info( + f"Configured Ray context: {self.available_cpus} CPUs, " + f"{self.available_memory // 1024**3}GB memory, {self.num_nodes} nodes" + ) + + def estimate_optimal_partitions(self, dataset_size_bytes: int) -> int: + """ + Estimate optimal number of partitions for a dataset based on size and resources. + """ + target_partition_size = (self.config.target_partition_size_mb or 64) * 1024**2 + size_based_partitions = max(1, dataset_size_bytes // target_partition_size) + max_partitions = self.available_cpus * ( + self.config.max_parallelism_multiplier or 2 + ) + return min(size_based_partitions, max_partitions) + + def should_use_broadcast_join( + self, dataset_size_bytes: int, threshold_mb: Optional[int] = None + ) -> bool: + """ + Determine if dataset is small enough for broadcast join. + """ + threshold = ( + threshold_mb + if threshold_mb is not None + else (self.config.broadcast_join_threshold_mb or 100) + ) + return dataset_size_bytes <= threshold * 1024**2 + + def estimate_processing_requirements( + self, dataset_size_bytes: int, operation_type: str + ) -> Dict[str, Any]: + """ + Estimate resource requirements for different operations. + """ + memory_multiplier = { + "read": 1.2, # 20% overhead for reading + "join": 3.0, # 3x for join operations + "aggregate": 2.0, # 2x for aggregations + "shuffle": 2.5, # 2.5x for shuffling + } + required_memory = dataset_size_bytes * memory_multiplier.get( + operation_type, 2.0 + ) + return { + "required_memory": required_memory, + "optimal_partitions": self.estimate_optimal_partitions(dataset_size_bytes), + "can_fit_in_memory": required_memory <= self.available_memory * 0.8, + "should_broadcast": self.should_use_broadcast_join(dataset_size_bytes), + } + + +class RayDataProcessor: + """ + Optimized data processing with Ray for feature store operations. + """ + + def __init__(self, resource_manager: RayResourceManager) -> None: + """ + Initialize the data processor with a resource manager. + """ + self.resource_manager = resource_manager + + def optimize_dataset_for_join(self, ds: Dataset, join_keys: List[str]) -> Dataset: + """ + Optimize dataset partitioning for join operations. 
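+
+        Illustrative usage sketch (`ds` is assumed to be a ray.data.Dataset of
+        feature rows; the join key is hypothetical):
+
+        ```python
+        processor = RayDataProcessor(RayResourceManager(store_config))
+        optimized_ds = processor.optimize_dataset_for_join(ds, join_keys=["driver_id"])
+        ```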
+ """ + dataset_size = ds.size_bytes() + optimal_partitions = self.resource_manager.estimate_optimal_partitions( + dataset_size + ) + if not join_keys: + # For datasets without join keys, use simple repartitioning + return ds.repartition(num_blocks=optimal_partitions) + # For datasets with join keys, use shuffle for better distribution + return ds.random_shuffle(num_blocks=optimal_partitions) + + def _manual_point_in_time_join( + self, + batch_df: pd.DataFrame, + features_df: pd.DataFrame, + join_keys: List[str], + feature_join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + ) -> pd.DataFrame: + """ + Perform manual point-in-time join when merge_asof fails. + + This method handles cases where merge_asof cannot be used due to: + - Entity mapping (different column names) + - Complex multi-entity joins + - Sorting issues with the data + """ + result = batch_df.copy() + for feat in requested_feats: + is_list_feature = False + if feat in features_df.columns: + sample_values = features_df[feat].dropna() + if not sample_values.empty: + sample_value = sample_values.iloc[0] + if isinstance(sample_value, (list, np.ndarray)): + is_list_feature = True + elif ( + features_df[feat].dtype == object + and sample_values.apply( + lambda x: isinstance(x, (list, np.ndarray)) + ).any() + ): + is_list_feature = True + + if is_list_feature: + result[feat] = [[] for _ in range(len(result))] + else: + if feat in features_df.columns and pd.api.types.is_datetime64_any_dtype( + features_df[feat] + ): + result[feat] = pd.Series( + [pd.NaT] * len(result), dtype="datetime64[ns, UTC]" + ) + else: + result[feat] = np.nan + + for _, entity_row in batch_df.iterrows(): + entity_matches = pd.Series( + [True] * len(features_df), index=features_df.index + ) + for entity_key, feature_key in zip(join_keys, feature_join_keys): + if entity_key in entity_row and feature_key in features_df.columns: + entity_value = entity_row[entity_key] + feature_column = features_df[feature_key] + if pd.api.types.is_scalar(entity_value): + entity_matches &= feature_column == entity_value + else: + if hasattr(entity_value, "__len__") and len(entity_value) > 0: + entity_matches &= feature_column.isin(entity_value) + else: + entity_matches &= pd.Series( + [False] * len(features_df), index=features_df.index + ) + if not entity_matches.any(): + continue + matching_features = features_df[entity_matches] + entity_timestamp = entity_row[timestamp_field] + if timestamp_field in matching_features.columns: + time_matches = matching_features[timestamp_field] <= entity_timestamp + matching_features = matching_features[time_matches] + if matching_features.empty: + continue + + if timestamp_field in matching_features.columns: + matching_features = matching_features.sort_values(timestamp_field) + latest_feature = matching_features.iloc[-1] + else: + latest_feature = matching_features.iloc[-1] + + entity_index = entity_row.name + for feat in requested_feats: + if feat in latest_feature: + feature_value = latest_feature[feat] + if pd.api.types.is_scalar(feature_value): + if pd.notna(feature_value): + result.loc[entity_index, feat] = feature_value + elif isinstance(feature_value, (list, tuple, np.ndarray)): + result.at[entity_index, feat] = feature_value + else: + try: + if pd.notna(feature_value): + result.at[entity_index, feat] = feature_value + except (ValueError, TypeError): + if feature_value is not None: + result.at[entity_index, feat] = feature_value + + return result + + def broadcast_join_features( + self, + entity_ds: Dataset, + 
feature_df: pd.DataFrame, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, + ) -> Dataset: + """Perform broadcast join for small feature datasets.""" + + # Put feature data in Ray object store for efficient broadcasting + feature_ref = ray.put(feature_df) + + def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: + """Join a batch with broadcast feature data.""" + features = ray.get(feature_ref) + + enable_logging = getattr( + self.resource_manager.config, "enable_ray_logging", False + ) + if enable_logging: + logger.info( + f"Processing feature view {feature_view_name} with join keys {join_keys}" + ) + + if original_join_keys: + feature_join_keys = original_join_keys + entity_join_keys = join_keys + else: + feature_join_keys = join_keys + entity_join_keys = join_keys + + feature_cols = [timestamp_field] + feature_join_keys + requested_feats + + available_feature_cols = [ + col for col in feature_cols if col in features.columns + ] + + if timestamp_field not in available_feature_cols: + raise ValueError( + f"Timestamp field '{timestamp_field}' not found in features columns: {list(features.columns)}" + ) + + missing_feats = [ + feat for feat in requested_feats if feat not in features.columns + ] + if missing_feats: + raise ValueError( + f"Requested features {missing_feats} not found in features columns: {list(features.columns)}" + ) + + features_filtered = features[available_feature_cols].copy() + + batch = normalize_timestamp_columns(batch, timestamp_field, inplace=True) + features_filtered = normalize_timestamp_columns( + features_filtered, timestamp_field, inplace=True + ) + + if not entity_join_keys: + batch_sorted = batch.sort_values(timestamp_field).reset_index(drop=True) + features_sorted = features_filtered.sort_values( + timestamp_field + ).reset_index(drop=True) + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + ) + else: + for key in entity_join_keys: + if key not in batch.columns: + batch[key] = np.nan + for key in feature_join_keys: + if key not in features_filtered.columns: + features_filtered[key] = np.nan + batch_clean = batch.dropna( + subset=entity_join_keys + [timestamp_field] + ).copy() + features_clean = features_filtered.dropna( + subset=feature_join_keys + [timestamp_field] + ).copy() + if batch_clean.empty or features_clean.empty: + return batch.head(0) + if timestamp_field in batch_clean.columns: + batch_sorted = batch_clean.sort_values( + timestamp_field, ascending=True + ).reset_index(drop=True) + else: + batch_sorted = batch_clean.reset_index(drop=True) + + right_sort_columns = [] + for key in feature_join_keys: + if key in features_clean.columns: + right_sort_columns.append(key) + if timestamp_field in features_clean.columns: + right_sort_columns.append(timestamp_field) + if right_sort_columns: + features_clean = features_clean.drop_duplicates( + subset=right_sort_columns, keep="last" + ) + features_sorted = features_clean.sort_values( + right_sort_columns, ascending=True + ).reset_index(drop=True) + else: + features_sorted = features_clean.reset_index(drop=True) + + if ( + timestamp_field in features_sorted.columns + and len(features_sorted) > 1 + ): + if feature_join_keys: + grouped = features_sorted.groupby(feature_join_keys, sort=False) + for name, group in grouped: + if not group[timestamp_field].is_monotonic_increasing: + 
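The broadcast path relies on a standard Ray pattern: `ray.put` stores the small feature frame once in the object store, and every `map_batches` task resolves the same reference with `ray.get` instead of shipping its own copy. A minimal sketch of that pattern in isolation (toy data, not the patch's join logic):

```python
import pandas as pd
import ray

ray.init(ignore_reinit_error=True)

# Small lookup table, placed in the object store once.
lookup = pd.DataFrame({"driver_id": [1, 2], "tier": ["gold", "silver"]})
lookup_ref = ray.put(lookup)

entity_ds = ray.data.from_pandas(pd.DataFrame({"driver_id": [1, 2, 1]}))

def join_batch(batch: pd.DataFrame) -> pd.DataFrame:
    # Every task fetches the same broadcast object from the local object store.
    return batch.merge(ray.get(lookup_ref), on="driver_id", how="left")

print(entity_ds.map_batches(join_batch, batch_format="pandas").to_pandas())
```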
features_sorted = features_sorted.sort_values( + feature_join_keys + [timestamp_field], + ascending=True, + ).reset_index(drop=True) + break + else: + if not features_sorted[timestamp_field].is_monotonic_increasing: + features_sorted = features_sorted.sort_values( + timestamp_field, ascending=True + ).reset_index(drop=True) + + try: + if feature_join_keys: + batch_dedup_cols = [ + k for k in entity_join_keys if k in batch_sorted.columns + ] + if timestamp_field in batch_sorted.columns: + batch_dedup_cols.append(timestamp_field) + if batch_dedup_cols: + batch_sorted = batch_sorted.drop_duplicates( + subset=batch_dedup_cols, keep="last" + ) + feature_dedup_cols = [ + k for k in feature_join_keys if k in features_sorted.columns + ] + if timestamp_field in features_sorted.columns: + feature_dedup_cols.append(timestamp_field) + if feature_dedup_cols: + features_sorted = features_sorted.drop_duplicates( + subset=feature_dedup_cols, keep="last" + ) + + if feature_join_keys: + if entity_join_keys == feature_join_keys: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + by=entity_join_keys, + direction="backward", + suffixes=("", "_right"), + ) + else: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + left_by=entity_join_keys, + right_by=feature_join_keys, + direction="backward", + suffixes=("", "_right"), + ) + else: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + suffixes=("", "_right"), + ) + + except Exception as e: + if enable_logging: + logger.warning( + f"merge_asof didn't work: {e}, implementing manual point-in-time join" + ) + result = self._manual_point_in_time_join( + batch_clean, + features_clean, + entity_join_keys, + feature_join_keys, + timestamp_field, + requested_feats, + ) + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + + return result + + return entity_ds.map_batches(join_batch_with_features, batch_format="pandas") + + def windowed_temporal_join( + self, + entity_ds: Dataset, + feature_ds: Dataset, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + window_size: Optional[str] = None, + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, + ) -> Dataset: + """Perform windowed temporal join for large datasets.""" + + window_size = window_size or ( + self.resource_manager.config.window_size_for_joins or "1H" + ) + entity_optimized = self.optimize_dataset_for_join(entity_ds, join_keys) + feature_optimized = self.optimize_dataset_for_join(feature_ds, join_keys) + entity_windowed = self._add_time_windows_and_source_marker( + entity_optimized, timestamp_field, "entity", window_size + ) + feature_windowed = self._add_time_windows_and_source_marker( + feature_optimized, timestamp_field, "feature", window_size + ) + combined_ds = entity_windowed.union(feature_windowed) + result_ds = combined_ds.map_batches( + self._apply_windowed_point_in_time_logic, + batch_format="pandas", + fn_kwargs={ + "timestamp_field": timestamp_field, + "join_keys": join_keys, + "requested_feats": requested_feats, + "full_feature_names": full_feature_names, + "feature_view_name": feature_view_name, + "original_join_keys": original_join_keys, + }, + ) + + return result_ds + + def 
_add_time_windows_and_source_marker( + self, ds: Dataset, timestamp_field: str, source_marker: str, window_size: str + ) -> Dataset: + """Add time windows and source markers to dataset.""" + + def add_window_and_source(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + if timestamp_field in batch.columns: + batch["time_window"] = ( + pd.to_datetime(batch[timestamp_field]) + .dt.floor(window_size) + .astype("datetime64[ns, UTC]") + ) + batch["_data_source"] = source_marker + return batch + + return ds.map_batches(add_window_and_source, batch_format="pandas") + + def _apply_windowed_point_in_time_logic( + self, + batch: pd.DataFrame, + timestamp_field: str, + join_keys: List[str], + requested_feats: List[str], + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, + ) -> pd.DataFrame: + """Apply point-in-time correctness within time windows.""" + + if len(batch) == 0: + return pd.DataFrame() + + result_chunks = [] + group_keys = ["time_window"] + join_keys + + for group_values, group_data in batch.groupby(group_keys): + entity_data = group_data[group_data["_data_source"] == "entity"].copy() + feature_data = group_data[group_data["_data_source"] == "feature"].copy() + if len(entity_data) > 0 and len(feature_data) > 0: + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + feature_clean = feature_data.drop( + columns=["time_window", "_data_source"] + ) + if join_keys: + merged = pd.merge_asof( + entity_clean.sort_values(join_keys + [timestamp_field]), + feature_clean.sort_values(join_keys + [timestamp_field]), + on=timestamp_field, + by=join_keys, + direction="backward", + ) + else: + merged = pd.merge_asof( + entity_clean.sort_values(timestamp_field), + feature_clean.sort_values(timestamp_field), + on=timestamp_field, + direction="backward", + ) + + result_chunks.append(merged) + elif len(entity_data) > 0: + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + for feat in requested_feats: + if feat not in entity_clean.columns: + entity_clean[feat] = np.nan + result_chunks.append(entity_clean) + + if result_chunks: + result = pd.concat(result_chunks, ignore_index=True) + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + + return result + else: + return pd.DataFrame() + + +class RayRetrievalJob(RetrievalJob): + def __init__( + self, + dataset_or_callable: Union[ + Dataset, pd.DataFrame, Callable[[], Union[Dataset, pd.DataFrame]] + ], + staging_location: Optional[str] = None, + config: Optional[RayOfflineStoreConfig] = None, + ): + self._dataset_or_callable = dataset_or_callable + self._staging_location = staging_location + self._config = config or RayOfflineStoreConfig() + self._cached_df: Optional[pd.DataFrame] = None + self._cached_dataset: Optional[Dataset] = None + self._metadata: Optional[RetrievalMetadata] = None + self._full_feature_names: bool = False + self._on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None + self._feature_refs: List[str] = [] + self._entity_df: Optional[pd.DataFrame] = None + self._prefer_ray_datasets: bool = True + + def _create_metadata(self) -> RetrievalMetadata: + """Create metadata from the entity DataFrame and feature references.""" + if self._entity_df is not None: + timestamp_col = _safe_infer_event_timestamp_column( + self._entity_df, 
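The distributed strategy buckets both sides by flooring the event timestamp to the configured window (one hour by default), tags each row with its source, and then applies the point-in-time logic per (window, join key) group. A quick illustration of just the bucketing step, with assumed timestamps:

```python
import pandas as pd

ts = pd.to_datetime(
    ["2024-01-01 10:05", "2024-01-01 10:40", "2024-01-01 11:10"], utc=True
)
# Flooring to a 1-hour window puts the first two rows in the same bucket.
print(ts.floor("1H").tolist())
# [2024-01-01 10:00:00+00:00, 2024-01-01 10:00:00+00:00, 2024-01-01 11:00:00+00:00]
```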
"event_timestamp" + ) + min_timestamp, max_timestamp = _safe_get_entity_timestamp_bounds( + self._entity_df, timestamp_col + ) + + keys = [col for col in self._entity_df.columns if col != timestamp_col] + else: + try: + result = self._resolve() + if isinstance(result, Dataset): + timestamp_col = _safe_infer_event_timestamp_column( + result, "event_timestamp" + ) + min_timestamp, max_timestamp = _safe_get_entity_timestamp_bounds( + result, timestamp_col + ) + schema = result.schema() + keys = [col for col in schema.names if col != timestamp_col] + else: + min_timestamp = None + max_timestamp = None + keys = [] + except Exception: + min_timestamp = None + max_timestamp = None + keys = [] + + return RetrievalMetadata( + features=self._feature_refs, + keys=keys, + min_event_timestamp=min_timestamp, + max_event_timestamp=max_timestamp, + ) + + def _set_metadata_info( + self, feature_refs: List[str], entity_df: pd.DataFrame + ) -> None: + """Set the feature references and entity DataFrame for metadata creation.""" + self._feature_refs = feature_refs + self._entity_df = entity_df + + def _resolve(self) -> Union[Dataset, pd.DataFrame]: + if callable(self._dataset_or_callable): + result = self._dataset_or_callable() + else: + result = self._dataset_or_callable + return result + + def _get_ray_dataset(self) -> Dataset: + """Get the result as a Ray Dataset, converting if necessary.""" + if self._cached_dataset is not None: + return self._cached_dataset + + result = self._resolve() + if isinstance(result, Dataset): + self._cached_dataset = result + return result + elif isinstance(result, pd.DataFrame): + self._cached_dataset = ray.data.from_pandas(result) + return self._cached_dataset + else: + raise ValueError(f"Unsupported result type: {type(result)}") + + def to_df( + self, + validation_reference: Optional[ValidationReference] = None, + timeout: Optional[int] = None, + ) -> pd.DataFrame: + if self._cached_df is not None and not self.on_demand_feature_views: + df = self._cached_df + else: + if self.on_demand_feature_views: + df = super().to_df( + validation_reference=validation_reference, timeout=timeout + ) + else: + if self._prefer_ray_datasets: + ray_ds = self._get_ray_dataset() + df = ray_ds.to_pandas() + else: + result = self._resolve() + if isinstance(result, pd.DataFrame): + df = result + else: + df = result.to_pandas() + self._cached_df = df + + if validation_reference: + try: + from feast.dqm.errors import ValidationFailed + + validation_result = validation_reference.profile.validate(df) + if not validation_result.is_success: + raise ValidationFailed(validation_result) + except ImportError: + logger.warning("DQM profiler not available, skipping validation") + except Exception as e: + logger.error(f"Validation failed: {e}") + raise ValueError(f"Data validation failed: {e}") + return df + + def to_arrow( + self, + validation_reference: Optional[ValidationReference] = None, + timeout: Optional[int] = None, + ) -> pa.Table: + if self.on_demand_feature_views: + return super().to_arrow( + validation_reference=validation_reference, timeout=timeout + ) + + if self._prefer_ray_datasets: + try: + ray_ds = self._get_ray_dataset() + if hasattr(ray_ds, "to_arrow"): + return ray_ds.to_arrow() + else: + df = ray_ds.to_pandas() + return pa.Table.from_pandas(df) + except Exception: + df = self.to_df( + validation_reference=validation_reference, timeout=timeout + ) + return pa.Table.from_pandas(df) + else: + result = self._resolve() + if isinstance(result, pd.DataFrame): + return pa.Table.from_pandas(result) 
+ else: + df = result.to_pandas() + return pa.Table.from_pandas(df) + + def to_remote_storage(self) -> list[str]: + if not self._staging_location: + raise ValueError("Staging location must be set for remote materialization.") + try: + ray_ds = self._get_ray_dataset() + RayOfflineStore._ensure_ray_initialized() + output_uri = os.path.join(self._staging_location, str(uuid.uuid4())) + ray_ds.write_parquet(output_uri) + return [output_uri] + except Exception as e: + raise RuntimeError(f"Failed to write to remote storage: {e}") + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + """Return metadata information about retrieval.""" + if self._metadata is None: + self._metadata = self._create_metadata() + return self._metadata + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return self._on_demand_feature_views or [] + + def to_sql(self) -> str: + raise NotImplementedError("SQL export not supported for Ray offline store") + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + if self._prefer_ray_datasets: + ray_ds = self._get_ray_dataset() + return ray_ds.to_pandas() + else: + return self._resolve().to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: + if self._prefer_ray_datasets: + ray_ds = self._get_ray_dataset() + try: + if hasattr(ray_ds, "to_arrow"): + return ray_ds.to_arrow() + else: + df = ray_ds.to_pandas() + return pa.Table.from_pandas(df) + except Exception: + df = ray_ds.to_pandas() + return pa.Table.from_pandas(df) + else: + result = self._resolve() + if isinstance(result, pd.DataFrame): + return pa.Table.from_pandas(result) + else: + df = result.to_pandas() + return pa.Table.from_pandas(df) + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: Optional[bool] = False, + timeout: Optional[int] = None, + ) -> str: + """Persist the dataset to storage using Ray operations.""" + + if not isinstance(storage, SavedDatasetFileStorage): + raise ValueError( + f"Ray offline store only supports SavedDatasetFileStorage, got {type(storage)}" + ) + destination_path = storage.file_options.uri + if not destination_path.startswith(("s3://", "gs://", "hdfs://")): + if not allow_overwrite and os.path.exists(destination_path): + raise SavedDatasetLocationAlreadyExists(location=destination_path) + try: + ray_ds = self._get_ray_dataset() + + if not destination_path.startswith(("s3://", "gs://", "hdfs://")): + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + + ray_ds.write_parquet(destination_path) + + return destination_path + except Exception as e: + raise RuntimeError(f"Failed to persist dataset to {destination_path}: {e}") + + def materialize(self) -> None: + """Materialize the Ray dataset to improve subsequent access performance.""" + try: + ray_ds = self._get_ray_dataset() + materialized_ds = ray_ds.materialize() + self._cached_dataset = materialized_ds + + if getattr(self._config, "enable_ray_logging", False): + logger.info("Ray dataset materialized successfully") + except Exception as e: + logger.warning(f"Failed to materialize Ray dataset: {e}") + + def schema(self) -> pa.Schema: + """Get the schema of the dataset efficiently using Ray operations.""" + try: + ray_ds = self._get_ray_dataset() + return ray_ds.schema() + except Exception: + df = self.to_df() + return pa.Table.from_pandas(df).schema + + +class RayOfflineStore(OfflineStore): + def __init__(self) -> None: + 
self._staging_location: Optional[str] = None + self._ray_initialized: bool = False + self._resource_manager: Optional[RayResourceManager] = None + self._data_processor: Optional[RayDataProcessor] = None + + @staticmethod + def _suppress_ray_logging() -> None: + """Suppress Ray and Ray Data logging completely.""" + import warnings + + # Suppress Ray warnings + warnings.filterwarnings("ignore", category=DeprecationWarning, module="ray") + warnings.filterwarnings("ignore", category=UserWarning, module="ray") + + # Set environment variables to suppress Ray output + os.environ["RAY_DISABLE_IMPORT_WARNING"] = "1" + os.environ["RAY_SUPPRESS_UNVERIFIED_TLS_WARNING"] = "1" + os.environ["RAY_LOG_LEVEL"] = "ERROR" + os.environ["RAY_DATA_LOG_LEVEL"] = "ERROR" + os.environ["RAY_DISABLE_PROGRESS_BARS"] = "1" + + # Suppress all Ray-related loggers + ray_loggers = [ + "ray", + "ray.data", + "ray.data.dataset", + "ray.data.context", + "ray.data._internal.streaming_executor", + "ray.data._internal.execution", + "ray.data._internal", + "ray.tune", + "ray.serve", + "ray.util", + "ray._private", + ] + for logger_name in ray_loggers: + logging.getLogger(logger_name).setLevel(logging.ERROR) + + # Configure DatasetContext to disable progress bars + try: + from ray.data.context import DatasetContext + + ctx = DatasetContext.get_current() + ctx.enable_progress_bars = False + if hasattr(ctx, "verbose_progress"): + ctx.verbose_progress = False + except Exception: + pass # Ignore if Ray Data is not available + + @staticmethod + def _ensure_ray_initialized(config: Optional[RepoConfig] = None) -> None: + """Ensure Ray is initialized with proper configuration.""" + ray_config = None + if config and hasattr(config, "offline_store"): + ray_config = config.offline_store + if isinstance(ray_config, RayOfflineStoreConfig): + if not ray_config.enable_ray_logging: + RayOfflineStore._suppress_ray_logging() + + if not ray.is_initialized(): + ray_init_kwargs: Dict[str, Any] = { + "ignore_reinit_error": True, + "include_dashboard": False, + } + + if ( + ray_config + and isinstance(ray_config, RayOfflineStoreConfig) + and not ray_config.enable_ray_logging + ): + ray_init_kwargs.update( + { + "log_to_driver": False, + "logging_level": "ERROR", + } + ) + + if config and hasattr(config, "offline_store"): + if isinstance(ray_config, RayOfflineStoreConfig): + if ray_config.ray_address: + ray_init_kwargs["address"] = ray_config.ray_address + else: + ray_init_kwargs.update( + { + "_node_ip_address": os.getenv( + "RAY_NODE_IP", "127.0.0.1" + ), + "num_cpus": os.cpu_count() or 4, + } + ) + + if ray_config.ray_conf: + ray_init_kwargs.update(ray_config.ray_conf) + else: + pass # Use default initialization + + ray.init(**ray_init_kwargs) + + ctx = DatasetContext.get_current() + ctx.shuffle_strategy = "sort" # type: ignore + ctx.enable_tensor_extension_casting = False + + if ( + ray_config + and isinstance(ray_config, RayOfflineStoreConfig) + and not ray_config.enable_ray_logging + ): + RayOfflineStore._suppress_ray_logging() + + if ray.is_initialized(): + cluster_resources = ray.cluster_resources() + if ( + not ray_config + or not isinstance(ray_config, RayOfflineStoreConfig) + or ray_config.enable_ray_logging + ): + logger.info( + f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " + f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" + ) + + def _init_ray(self, config: RepoConfig) -> None: + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + 
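All of the offline store's entry points funnel through this initialization: if `ray_address` is configured the engine joins that cluster, otherwise it starts a local one sized to the host. A condensed sketch of the resulting `ray.init` call for each case (the address value is illustrative):

```python
import os
import ray

ray_address = None  # e.g. "localhost:10001" to join an existing cluster

kwargs = {"ignore_reinit_error": True, "include_dashboard": False}
if ray_address:
    kwargs["address"] = ray_address            # connect to the remote cluster
else:
    kwargs["num_cpus"] = os.cpu_count() or 4   # local cluster on this machine

ray.init(**kwargs)
print(ray.cluster_resources().get("CPU", 0))
```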
RayOfflineStore._ensure_ray_initialized(config) + + if not ray_config.enable_ray_logging: + RayOfflineStore._suppress_ray_logging() + + if self._resource_manager is None: + self._resource_manager = RayResourceManager(ray_config) + self._resource_manager.configure_ray_context() + if self._data_processor is None: + self._data_processor = RayDataProcessor(self._resource_manager) + + def _get_source_path(self, source: DataSource, config: RepoConfig) -> str: + if not isinstance(source, FileSource): + raise ValueError("RayOfflineStore currently only supports FileSource") + repo_path = getattr(config, "repo_path", None) + uri = FileSource.get_uri_for_file_path(repo_path, source.path) + return uri + + def _optimize_dataset_for_operation(self, ds: Dataset, operation: str) -> Dataset: + """Optimize dataset for specific operations.""" + if self._resource_manager is None: + return ds + + dataset_size = ds.size_bytes() + requirements = self._resource_manager.estimate_processing_requirements( + dataset_size, operation + ) + + if requirements["can_fit_in_memory"]: + ds = ds.materialize() + + optimal_partitions = requirements["optimal_partitions"] + current_partitions = ds.num_blocks() + + if current_partitions != optimal_partitions: + if getattr(self._resource_manager.config, "enable_ray_logging", False): + logger.debug( + f"Repartitioning dataset from {current_partitions} to {optimal_partitions} blocks" + ) + ds = ds.repartition(num_blocks=optimal_partitions) + + return ds + + @staticmethod + def offline_write_batch( + config: RepoConfig, + feature_view: FeatureView, + table: pa.Table, + progress: Optional[Callable[[int], Any]] = None, + ) -> None: + """Write batch data using Ray operations with performance monitoring.""" + import time + + start_time = time.time() + + RayOfflineStore._ensure_ray_initialized(config) + + repo_path = getattr(config, "repo_path", None) or os.getcwd() + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + + if not ray_config.enable_ray_logging: + RayOfflineStore._suppress_ray_logging() + assert isinstance(feature_view.batch_source, FileSource) + + validation_result = _safe_validate_schema( + config, feature_view.batch_source, table.column_names, "offline_write_batch" + ) + + if validation_result: + expected_schema, expected_columns = validation_result + if expected_columns != table.column_names and set(expected_columns) == set( + table.column_names + ): + if getattr(ray_config, "enable_ray_logging", False): + logger.info("Reordering table columns to match expected schema") + table = table.select(expected_columns) + + batch_source_path = feature_view.batch_source.file_options.uri + feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) + + ds = ray.data.from_arrow(table) + + try: + if feature_path.endswith(".parquet"): + if os.path.exists(feature_path): + existing_ds = ray.data.read_parquet(feature_path) + combined_ds = existing_ds.union(ds) + combined_ds.write_parquet(feature_path) + else: + ds.write_parquet(feature_path) + else: + os.makedirs(feature_path, exist_ok=True) + ds.write_parquet(feature_path) + + if progress: + progress(table.num_rows) + + except Exception: + if getattr(ray_config, "enable_ray_logging", False): + logger.info("Falling back to pandas-based writing") + df = table.to_pandas() + if feature_path.endswith(".parquet"): + if os.path.exists(feature_path): + existing_df = pd.read_parquet(feature_path) + combined_df = pd.concat([existing_df, df], ignore_index=True) + 
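The write path above appends by unioning the incoming Arrow table with whatever already exists at the destination and rewriting the Parquet output (a pandas path handles Ray write failures). A minimal sketch of the Ray Data append pattern, using a temporary directory instead of a registered batch source:

```python
import os
import tempfile

import pyarrow as pa
import ray

ray.init(ignore_reinit_error=True)
out_dir = tempfile.mkdtemp()

new_rows = pa.table({"driver_id": [1001], "conv_rate": [0.25]})
ds = ray.data.from_arrow(new_rows)

if os.listdir(out_dir):
    # Append: union with the existing data, then rewrite.
    ds = ray.data.read_parquet(out_dir).union(ds)
ds.write_parquet(out_dir)

print(ray.data.read_parquet(out_dir).count())  # 1
```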
combined_df.to_parquet(feature_path, index=False) + else: + df.to_parquet(feature_path, index=False) + else: + os.makedirs(feature_path, exist_ok=True) + ds_fallback = ray.data.from_pandas(df) + ds_fallback.write_parquet(feature_path) + + if progress: + progress(table.num_rows) + + duration = time.time() - start_time + if getattr(ray_config, "enable_ray_logging", False): + logger.info( + f"Ray offline_write_batch performance: {table.num_rows} rows in {duration:.2f}s " + f"({table.num_rows / duration:.0f} rows/s)" + ) + + def online_write_batch( + self, + config: RepoConfig, + table: pa.Table, + progress: Optional[Callable[[int], Any]] = None, + ) -> None: + """Ray offline store doesn't support online writes.""" + raise NotImplementedError("Ray offline store doesn't support online writes") + + @staticmethod + def _process_filtered_batch( + batch: pd.DataFrame, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_columns: List[str], + timestamp_field_mapped: str, + ) -> pd.DataFrame: + batch = make_df_tzaware(batch) + if batch.empty: + return _handle_empty_dataframe_case( + join_key_columns, feature_name_columns, timestamp_columns + ) + all_required_columns = _build_required_columns( + join_key_columns, feature_name_columns, timestamp_columns + ) + if not join_key_columns: + batch[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + available_columns = [ + col for col in all_required_columns if col in batch.columns + ] + batch = batch[available_columns] + if ( + "event_timestamp" not in batch.columns + and timestamp_field_mapped != "event_timestamp" + ): + if timestamp_field_mapped in batch.columns: + batch["event_timestamp"] = batch[timestamp_field_mapped] + return batch + + @staticmethod + def _load_and_filter_dataset( + source_path: str, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: Optional[datetime], + end_date: Optional[datetime], + ) -> pd.DataFrame: + try: + field_mapping = getattr(data_source, "field_mapping", None) + ds = RayOfflineStore._create_filtered_dataset( + source_path, timestamp_field, start_date, end_date + ) + df = ds.to_pandas() + if field_mapping: + df = df.rename(columns=field_mapping) + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get(created_timestamp_column, created_timestamp_column) + if field_mapping and created_timestamp_column + else created_timestamp_column + ) + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + df = normalize_timestamp_columns(df, timestamp_columns, inplace=True) + df = RayOfflineStore._process_filtered_batch( + df, + join_key_columns, + feature_name_columns, + timestamp_columns, + timestamp_field_mapped, + ) + existing_timestamp_columns = [ + col for col in timestamp_columns if col in df.columns + ] + if existing_timestamp_columns: + df = df.sort_values(existing_timestamp_columns, ascending=False) + df = df.reset_index(drop=True) + return df + except Exception as e: + raise RuntimeError(f"Failed to load data from {source_path}: {e}") + + @staticmethod + def _load_and_filter_dataset_ray( + source_path: str, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: 
Optional[datetime], + end_date: Optional[datetime], + ) -> Dataset: + try: + field_mapping = getattr(data_source, "field_mapping", None) + ds = RayOfflineStore._create_filtered_dataset( + source_path, timestamp_field, start_date, end_date + ) + if field_mapping: + ds = apply_field_mapping(ds, field_mapping) + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get(created_timestamp_column, created_timestamp_column) + if field_mapping and created_timestamp_column + else created_timestamp_column + ) + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + # Exclude __log_timestamp from normalization as it's used for time range filtering + exclude_columns = ( + ["__log_timestamp"] if "__log_timestamp" in timestamp_columns else [] + ) + ds = normalize_timestamp_columns( + ds, timestamp_columns, exclude_columns=exclude_columns + ) + ds = ds.map_batches( + lambda batch: RayOfflineStore._process_filtered_batch( + batch, + join_key_columns, + feature_name_columns, + timestamp_columns, + timestamp_field_mapped, + ), + batch_format="pandas", + ) + timestamp_columns_existing = [ + col for col in timestamp_columns if col in ds.schema().names + ] + if timestamp_columns_existing: + ds = ds.sort(timestamp_columns_existing, descending=True) + + return ds + except Exception as e: + raise RuntimeError(f"Failed to load data from {source_path}: {e}") + + @staticmethod + def _pull_latest_processing_ray( + ds: Dataset, + join_key_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + field_mapping: Optional[Dict[str, str]] = None, + ) -> Dataset: + """ + Ray-native processing for pull_latest operations with deduplication. 
+ Args: + ds: Ray Dataset to process + join_key_columns: List of join key columns + timestamp_field: Name of the timestamp field + created_timestamp_column: Optional created timestamp column + field_mapping: Optional field mapping dictionary + Returns: + Ray Dataset with latest records only + """ + if not join_key_columns: + return ds + + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get(created_timestamp_column, created_timestamp_column) + if field_mapping and created_timestamp_column + else created_timestamp_column + ) + + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + + def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame: + if batch.empty: + return batch + + existing_timestamp_columns = [ + col for col in timestamp_columns if col in batch.columns + ] + + sort_columns = join_key_columns + existing_timestamp_columns + if sort_columns: + batch = batch.sort_values( + sort_columns, + ascending=[True] * len(join_key_columns) + + [False] * len(existing_timestamp_columns), + ) + batch = batch.drop_duplicates(subset=join_key_columns, keep="first") + + return batch + + return ds.map_batches(deduplicate_batch, batch_format="pandas") + + @staticmethod + def pull_latest_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: datetime, + end_date: datetime, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + source_path = store._get_source_path(data_source, config) + + def _load_ray_dataset(): + ds = store._load_and_filter_dataset_ray( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) + field_mapping = getattr(data_source, "field_mapping", None) + ds = store._pull_latest_processing_ray( + ds, + join_key_columns, + timestamp_field, + created_timestamp_column, + field_mapping, + ) + + return ds + + def _load_pandas_fallback(): + return store._load_and_filter_dataset( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) + + try: + return RayRetrievalJob( + _load_ray_dataset, + staging_location=config.offline_store.storage_path, + config=config.offline_store, + ) + except Exception as e: + logger.warning(f"Ray-native processing failed: {e}, falling back to pandas") + return RayRetrievalJob( + _load_pandas_fallback, + staging_location=config.offline_store.storage_path, + config=config.offline_store, + ) + + @staticmethod + def pull_all_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + source_path = store._get_source_path(data_source, config) + + fs, path_in_fs = fsspec.core.url_to_fs(source_path) + if not fs.exists(path_in_fs): + raise FileNotFoundError(f"Parquet path does not exist: {source_path}") + + def _load_ray_dataset(): + return store._load_and_filter_dataset_ray( + 
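`_pull_latest_processing_ray` reduces each key to its newest record by sorting join keys ascending and timestamps descending, then dropping duplicates and keeping the first row per key. The same idiom on a plain DataFrame, with assumed values:

```python
import pandas as pd

df = pd.DataFrame({
    "driver_id": [1001, 1001, 1002],
    "event_timestamp": pd.to_datetime(
        ["2024-01-01", "2024-01-02", "2024-01-01"], utc=True),
    "conv_rate": [0.1, 0.2, 0.5],
})

# Newest row per driver_id: keys ascending, timestamps descending, keep first.
latest = (
    df.sort_values(["driver_id", "event_timestamp"], ascending=[True, False])
      .drop_duplicates(subset=["driver_id"], keep="first")
)
print(latest["conv_rate"].tolist())  # [0.2, 0.5]
```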
source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) + + def _load_pandas_fallback(): + return store._load_and_filter_dataset( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) + + try: + return RayRetrievalJob( + _load_ray_dataset, + staging_location=config.offline_store.storage_path, + config=config.offline_store, + ) + except Exception as e: + logger.warning(f"Ray-native processing failed: {e}, falling back to pandas") + return RayRetrievalJob( + _load_pandas_fallback, + staging_location=config.offline_store.storage_path, + config=config.offline_store, + ) + + @staticmethod + def write_logged_features( + config: RepoConfig, + data: Union[pa.Table, Path], + source: LoggingSource, + logging_config: LoggingConfig, + registry: BaseRegistry, + ) -> None: + RayOfflineStore._ensure_ray_initialized(config) + + ray_config = getattr(config, "offline_store", None) + if ( + ray_config + and isinstance(ray_config, RayOfflineStoreConfig) + and not ray_config.enable_ray_logging + ): + RayOfflineStore._suppress_ray_logging() + + destination = logging_config.destination + assert isinstance(destination, FileLoggingDestination), ( + f"Ray offline store only supports FileLoggingDestination for logging, " + f"got {type(destination)}" + ) + + repo_path = getattr(config, "repo_path", None) or os.getcwd() + absolute_path = FileSource.get_uri_for_file_path(repo_path, destination.path) + + try: + if isinstance(data, Path): + ds = ray.data.read_parquet(str(data)) + else: + ds = ray.data.from_arrow(data) + + # Normalize feature timestamp precision to seconds to match test expectations during write + # Note: Don't normalize __log_timestamp as it's used for time range filtering + def normalize_timestamps(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + for col in batch.columns: + if ( + pd.api.types.is_datetime64_any_dtype(batch[col]) + and col != "__log_timestamp" + ): + batch[col] = batch[col].dt.floor("s") + return batch + + ds = ds.map_batches(normalize_timestamps, batch_format="pandas") + ds = ds.materialize() + filesystem, resolved_path = FileSource.create_filesystem_and_path( + absolute_path, destination.s3_endpoint_override + ) + path_obj = Path(resolved_path) + if path_obj.suffix == ".parquet": + path_obj = path_obj.with_suffix("") + if not absolute_path.startswith(("s3://", "gs://")): + path_obj.mkdir(parents=True, exist_ok=True) + ds.write_parquet(str(path_obj)) + except Exception as e: + raise RuntimeError(f"Failed to write logged features: {e}") + + @staticmethod + def create_saved_dataset_destination( + config: RepoConfig, + name: str, + path: Optional[str] = None, + ) -> SavedDatasetStorage: + """Create a saved dataset destination for Ray offline store.""" + + if path is None: + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + base_storage_path = ray_config.storage_path or "/tmp/ray-storage" + path = f"{base_storage_path}/saved_datasets/{name}.parquet" + + return SavedDatasetFileStorage(path=path) + + @staticmethod + def _create_filtered_dataset( + source_path: str, + timestamp_field: str, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> Dataset: + """Helper method to create a filtered dataset based on timestamp range.""" + ds = ray.data.read_parquet(source_path) + + try: + col_names = ds.schema().names + if 
timestamp_field not in col_names: + raise ValueError( + f"Timestamp field '{timestamp_field}' not found in columns: {col_names}" + ) + except Exception as e: + raise ValueError(f"Failed to get dataset schema: {e}") + + def normalize(dt): + return make_tzaware(dt) if dt and dt.tzinfo is None else dt + + start_date = normalize(start_date) + end_date = normalize(end_date) + + try: + if start_date and end_date: + + def filter_by_timestamp_range(batch): + return (batch[timestamp_field] >= start_date) & ( + batch[timestamp_field] <= end_date + ) + + ds = ds.filter(filter_by_timestamp_range) + elif start_date: + + def filter_by_start_date(batch): + return batch[timestamp_field] >= start_date + + ds = ds.filter(filter_by_start_date) + elif end_date: + + def filter_by_end_date(batch): + return batch[timestamp_field] <= end_date + + ds = ds.filter(filter_by_end_date) + except Exception as e: + raise RuntimeError(f"Failed to filter dataset by timestamp: {e}") + + return ds + + @staticmethod + def get_historical_features( + config: RepoConfig, + feature_views: List[FeatureView], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + # Load entity_df as Ray dataset for distributed processing + if isinstance(entity_df, str): + entity_ds = ray.data.read_csv(entity_df) + entity_df_sample = entity_ds.limit(1000).to_pandas() + else: + entity_ds = ray.data.from_pandas(entity_df) + entity_df_sample = entity_df.copy() + + entity_ds = ensure_timestamp_compatibility(entity_ds, ["event_timestamp"]) + on_demand_feature_views = OnDemandFeatureView.get_requested_odfvs( + feature_refs, project, registry + ) + for odfv in on_demand_feature_views: + odfv_request_data_schema = odfv.get_request_data_schema() + for feature_name in odfv_request_data_schema.keys(): + if feature_name not in entity_df_sample.columns: + raise RequestDataNotFoundInEntityDfException( + feature_name=feature_name, + feature_view_name=odfv.name, + ) + + odfv_names = {odfv.name for odfv in on_demand_feature_views} + regular_feature_views = [ + fv for fv in feature_views if fv.name not in odfv_names + ] + global_field_mappings = {} + for fv in regular_feature_views: + mapping = getattr(fv.batch_source, "field_mapping", None) + if mapping: + for k, v in mapping.items(): + global_field_mappings[v] = k + + if global_field_mappings: + cols_to_rename = { + v: k + for k, v in global_field_mappings.items() + if v in entity_df_sample.columns + } + if cols_to_rename: + entity_ds = apply_field_mapping(entity_ds, cols_to_rename) + + result_ds = entity_ds + for fv in regular_feature_views: + fv_feature_refs = [ + ref + for ref in feature_refs + if ref.startswith(fv.projection.name_to_use() + ":") + ] + if not fv_feature_refs: + continue + + entities = fv.entities or [] + entity_objs = [registry.get_entity(e, project) for e in entities] + original_join_keys, _, timestamp_field, created_col = _get_column_names( + fv, entity_objs + ) + + if fv.projection.join_key_map: + join_keys = [ + fv.projection.join_key_map.get(key, key) + for key in original_join_keys + ] + else: + join_keys = original_join_keys + + requested_feats = [ref.split(":", 1)[1] for ref in fv_feature_refs] + + available_feature_names = [f.name for f in fv.features] + missing_feats = [ + f for f in requested_feats if f not in available_feature_names + ] + if missing_feats: + raise KeyError( + f"Requested features {missing_feats} not 
found in feature view '{fv.name}' " + f"(available: {available_feature_names})" + ) + + source_info = resolve_feature_view_source_with_fallback( + fv, config, is_materialization=False + ) + + # Read from the resolved data source + source_path = store._get_source_path(source_info.data_source, config) + feature_ds = ray.data.read_parquet(source_path) + logger.info( + f"Reading feature view {fv.name}: {source_info.source_description}" + ) + + # Apply transformation if available + if source_info.has_transformation and source_info.transformation_func: + transformation_serialized = dill.dumps(source_info.transformation_func) + + def apply_transformation_with_serialized_func( + batch: pd.DataFrame, + ) -> pd.DataFrame: + if batch.empty: + return batch + try: + logger.debug( + f"Applying transformation to batch with columns: {list(batch.columns)}" + ) + transformation_func = dill.loads(transformation_serialized) + result = transformation_func(batch) + logger.debug( + f"Transformation result has columns: {list(result.columns)}" + ) + return result + except Exception as e: + logger.error(f"Transformation failed for {fv.name}: {e}") + return batch + + feature_ds = feature_ds.map_batches( + apply_transformation_with_serialized_func, batch_format="pandas" + ) + logger.info(f"Applied transformation to feature view {fv.name}") + elif source_info.has_transformation: + logger.warning( + f"Feature view {fv.name} marked as having transformation but no UDF found" + ) + + feature_size = feature_ds.size_bytes() or 0 + + field_mapping = getattr(fv.batch_source, "field_mapping", None) + if field_mapping: + feature_ds = apply_field_mapping(feature_ds, field_mapping) + join_keys = [field_mapping.get(k, k) for k in join_keys] + timestamp_field = field_mapping.get(timestamp_field, timestamp_field) + if created_col: + created_col = field_mapping.get(created_col, created_col) + + if ( + timestamp_field != "event_timestamp" + and timestamp_field not in entity_df_sample.columns + and "event_timestamp" in entity_df_sample.columns + ): + + def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + batch[timestamp_field] = batch["event_timestamp"] + return batch + + result_ds = result_ds.map_batches( + add_timestamp_field, batch_format="pandas" + ) + result_ds = normalize_timestamp_columns(result_ds, timestamp_field) + + if store._resource_manager is None: + raise ValueError("Resource manager not initialized") + requirements = store._resource_manager.estimate_processing_requirements( + feature_size, "join" + ) + + if requirements["should_broadcast"]: + # Use broadcast join for small feature datasets + if getattr(store._resource_manager.config, "enable_ray_logging", False): + logger.info( + f"Using broadcast join for {fv.name} (size: {feature_size // 1024**2}MB)" + ) + feature_df = feature_ds.to_pandas() + feature_df = ensure_timestamp_compatibility( + feature_df, [timestamp_field] + ) + + if store._data_processor is None: + raise ValueError("Data processor not initialized") + result_ds = store._data_processor.broadcast_join_features( + result_ds, + feature_df, + join_keys, + timestamp_field, + requested_feats, + full_feature_names, + fv.projection.name_to_use(), + original_join_keys if fv.projection.join_key_map else None, + ) + else: + # Use distributed windowed join for large feature datasets + if getattr(store._resource_manager.config, "enable_ray_logging", False): + logger.info( + f"Using distributed join for {fv.name} (size: {feature_size // 1024**2}MB)" + ) + feature_ds = 
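Feature view transformations are shipped to workers as dill-serialized bytes and rebuilt inside each `map_batches` task, so the task closure only has to carry bytes rather than the original function object. A standalone sketch of that pattern (the `add_offset` UDF is hypothetical, not part of the patch):

```python
import dill
import pandas as pd
import ray

ray.init(ignore_reinit_error=True)

def add_offset(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df["conv_rate_plus_100"] = df["conv_rate"] + 100
    return df

serialized = dill.dumps(add_offset)      # bytes captured by the task closure

def apply_serialized(batch: pd.DataFrame) -> pd.DataFrame:
    func = dill.loads(serialized)        # rebuilt on the worker
    return func(batch)

ds = ray.data.from_pandas(pd.DataFrame({"conv_rate": [0.1, 0.2]}))
print(ds.map_batches(apply_serialized, batch_format="pandas").to_pandas())
```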
ensure_timestamp_compatibility( + feature_ds, [timestamp_field] + ) + + if store._data_processor is None: + raise ValueError("Data processor not initialized") + result_ds = store._data_processor.windowed_temporal_join( + result_ds, + feature_ds, + join_keys, + timestamp_field, + requested_feats, + window_size=config.offline_store.window_size_for_joins, + full_feature_names=full_feature_names, + feature_view_name=fv.projection.name_to_use(), + original_join_keys=original_join_keys + if fv.projection.join_key_map + else None, + ) + + def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + + existing_columns = set(batch.columns) + for col in entity_df_sample.columns: + if col not in existing_columns: + if len(batch) <= len(entity_df_sample): + batch[col] = entity_df_sample[col].iloc[: len(batch)].values + else: + repeated_values = np.tile( + entity_df_sample[col].values, + (len(batch) // len(entity_df_sample) + 1), + ) + batch[col] = repeated_values[: len(batch)] + + if "event_timestamp" not in batch.columns: + if "event_timestamp" in entity_df_sample.columns: + batch["event_timestamp"] = ( + entity_df_sample["event_timestamp"].iloc[: len(batch)].values + ) + batch = normalize_timestamp_columns( + batch, "event_timestamp", inplace=True + ) + elif timestamp_field in batch.columns: + batch["event_timestamp"] = batch[timestamp_field] + + return batch + + result_ds = result_ds.map_batches(finalize_result, batch_format="pandas") + result_ds = _convert_feature_column_types(result_ds, regular_feature_views) + + storage_path = config.offline_store.storage_path + if not storage_path: + raise ValueError("Storage path must be set in config") + + job = RayRetrievalJob( + result_ds, staging_location=storage_path, config=config.offline_store + ) + job._full_feature_names = full_feature_names + job._on_demand_feature_views = on_demand_feature_views + job._feature_refs = feature_refs + job._entity_df = entity_df_sample + job._metadata = job._create_metadata() + return job diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py new file mode 100644 index 00000000000..5ab82f8ef47 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py @@ -0,0 +1,146 @@ +import pandas as pd +import pytest + +from feast.utils import _utc_now +from tests.integration.feature_repos.repo_configuration import ( + construct_universal_feature_views, +) +from tests.integration.feature_repos.universal.entities import driver + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_ray_offline_store_basic_write_and_read(environment, universal_data_sources): + """Test basic write and read functionality with Ray offline store.""" + store = environment.feature_store + _, _, data_sources = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + driver_fv = feature_views.driver + store.apply([driver(), driver_fv]) + + now = _utc_now() + ts = pd.Timestamp(now).round("ms") + + # Write data to offline store + df_to_write = pd.DataFrame.from_dict( + { + "event_timestamp": [ts, ts], + "driver_id": [1001, 1002], + "conv_rate": [0.1, 0.2], + 
"acc_rate": [0.9, 0.8], + "avg_daily_trips": [10, 20], + "created": [ts, ts], + }, + ) + + store.write_to_offline_store( + driver_fv.name, df_to_write, allow_registry_cache=False + ) + + # Read data back + entity_df = pd.DataFrame({"driver_id": [1001, 1002], "event_timestamp": [ts, ts]}) + + result_df = store.get_historical_features( + entity_df=entity_df, + features=[ + "driver_stats:conv_rate", + "driver_stats:acc_rate", + "driver_stats:avg_daily_trips", + ], + full_feature_names=False, + ).to_df() + + assert len(result_df) == 2 + assert "conv_rate" in result_df.columns + assert "acc_rate" in result_df.columns + assert "avg_daily_trips" in result_df.columns + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: f"full:{v}") +def test_ray_offline_store_historical_features( + environment, universal_data_sources, full_feature_names +): + """Test historical features retrieval with Ray offline store.""" + store = environment.feature_store + + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + + entity_df_with_request_data = datasets.entity_df.copy(deep=True) + entity_df_with_request_data["val_to_add"] = [ + i for i in range(len(entity_df_with_request_data)) + ] + + store.apply( + [ + driver(), + *feature_views.values(), + ] + ) + + job = store.get_historical_features( + entity_df=entity_df_with_request_data, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "conv_rate_plus_100:conv_rate_plus_100", + ], + full_feature_names=full_feature_names, + ) + + # Test DataFrame conversion + result_df = job.to_df() + assert len(result_df) > 0 + assert "event_timestamp" in result_df.columns + + # Test Arrow conversion + result_table = job.to_arrow().to_pandas() + assert len(result_table) > 0 + assert "event_timestamp" in result_table.columns + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_ray_offline_store_persist(environment, universal_data_sources): + """Test dataset persistence with Ray offline store.""" + store = environment.feature_store + + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + + entity_df_with_request_data = datasets.entity_df.copy(deep=True) + entity_df_with_request_data["val_to_add"] = [ + i for i in range(len(entity_df_with_request_data)) + ] + + store.apply( + [ + driver(), + *feature_views.values(), + ] + ) + + job = store.get_historical_features( + entity_df=entity_df_with_request_data, + features=[ + "driver_stats:conv_rate", + "customer_profile:current_balance", + ], + full_feature_names=False, + ) + + # Test persisting the dataset + from feast.saved_dataset import SavedDatasetFileStorage + + storage = SavedDatasetFileStorage(path="data/test_saved_dataset.parquet") + saved_path = job.persist(storage, allow_overwrite=True) + + assert saved_path == "data/test_saved_dataset.parquet" + + # Verify the saved dataset exists + import os + + assert os.path.exists(saved_path) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py new file mode 100644 index 00000000000..6e1fa66b102 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py @@ -0,0 +1,128 @@ +import os +import tempfile +from typing import Any, Dict, 
Optional + +from feast.data_format import ParquetFormat +from feast.data_source import DataSource +from feast.feature_logging import LoggingDestination +from feast.infra.offline_stores.contrib.ray_offline_store.ray import ( + RayOfflineStoreConfig, +) +from feast.infra.offline_stores.file_source import ( + FileLoggingDestination, + FileSource, + SavedDatasetFileStorage, +) +from feast.repo_config import FeastConfigBaseModel +from feast.saved_dataset import SavedDatasetStorage +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.data_source_creator import ( + DataSourceCreator, +) + + +class RayDataSourceCreator(DataSourceCreator): + def __init__(self, project_name: str, *args, **kwargs): + super().__init__(project_name, *args, **kwargs) + self.offline_store_config = RayOfflineStoreConfig( + type="ray", + storage_path="/tmp/ray-storage", + ray_address=None, + broadcast_join_threshold_mb=25, + max_parallelism_multiplier=1, + target_partition_size_mb=16, + enable_ray_logging=False, + ray_conf={ + "num_cpus": 1, + "object_store_memory": 80 * 1024 * 1024, + "_memory": 400 * 1024 * 1024, + }, + ) + self.files: list[Any] = [] + self.dirs: list[str] = [] + + def create_offline_store_config(self) -> FeastConfigBaseModel: + return self.offline_store_config + + def create_data_source( + self, + df: Any, + destination_name: str, + created_timestamp_column: Optional[Any] = "created_ts", + field_mapping: Optional[Dict[str, str]] = None, + timestamp_field: Optional[str] = "ts", + ) -> DataSource: + # For Ray, we'll use parquet files as the underlying storage + destination_name = self.get_prefixed_table_name(destination_name) + + f = tempfile.NamedTemporaryFile( + prefix=f"{self.project_name}_{destination_name}", + suffix=".parquet", + delete=False, + ) + df.to_parquet(f.name) + self.files.append(f) + + return FileSource( + file_format=ParquetFormat(), + path=f.name, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping or {"ts_1": "ts"}, + ) + + def get_prefixed_table_name(self, suffix: str) -> str: + return f"{self.project_name}.{suffix}" + + def create_saved_dataset_destination(self) -> SavedDatasetStorage: + d = tempfile.mkdtemp(prefix=self.project_name) + self.dirs.append(d) + return SavedDatasetFileStorage( + path=d, + file_format=ParquetFormat(), + ) + + def create_logged_features_destination(self) -> LoggingDestination: + d = tempfile.mkdtemp(prefix=self.project_name) + self.dirs.append(d) + return FileLoggingDestination(path=d) + + def teardown(self) -> None: + # Clean up any temporary files or resources + import shutil + + for f in self.files: + f.close() + try: + os.unlink(f.name) + except OSError: + pass + + for d in self.dirs: + if os.path.exists(d): + shutil.rmtree(d) + + def get_saved_dataset_data_source(self) -> Dict[str, str]: + return { + "type": "parquet", + "path": "data/saved_dataset.parquet", + } + + @staticmethod + def xdist_groups() -> list[str]: + """ + Return xdist group names for Ray tests. + This ensures all Ray tests run on the same pytest worker to avoid OOM issues. 
+ """ + return ["ray"] + + +# Define the full repo configurations for Ray offline store +FULL_REPO_CONFIGS = [ + IntegrationTestRepoConfig( + provider="local", + offline_store_creator=RayDataSourceCreator, + ), +] diff --git a/sdk/python/feast/infra/ray_shared_utils.py b/sdk/python/feast/infra/ray_shared_utils.py new file mode 100644 index 00000000000..6fa873ab6ae --- /dev/null +++ b/sdk/python/feast/infra/ray_shared_utils.py @@ -0,0 +1,362 @@ +from typing import Dict, List, Optional, Union + +import numpy as np +import pandas as pd +from ray.data import Dataset + + +def normalize_timestamp_columns( + data: Union[pd.DataFrame, Dataset], + columns: Union[str, List[str]], + inplace: bool = False, + exclude_columns: Optional[List[str]] = None, +) -> Union[pd.DataFrame, Dataset]: + column_list = [columns] if isinstance(columns, str) else columns + exclude_columns = exclude_columns or [] + + def apply_normalization(series: pd.Series) -> pd.Series: + return ( + pd.to_datetime(series, utc=True, errors="coerce") + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + if isinstance(data, Dataset): + + def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame: + for column in column_list: + if ( + not batch.empty + and column in batch.columns + and column not in exclude_columns + ): + batch[column] = apply_normalization(batch[column]) + return batch + + return data.map_batches(normalize_batch, batch_format="pandas") + else: + if not inplace: + data = data.copy() + for column in column_list: + if column in data.columns and column not in exclude_columns: + data[column] = apply_normalization(data[column]) + return data + + +def ensure_timestamp_compatibility( + data: Union[pd.DataFrame, Dataset], + timestamp_fields: List[str], + inplace: bool = False, +) -> Union[pd.DataFrame, Dataset]: + from feast.utils import make_df_tzaware + + if isinstance(data, Dataset): + + def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame: + batch = make_df_tzaware(batch) + for field in timestamp_fields: + if field in batch.columns: + batch[field] = ( + pd.to_datetime(batch[field], utc=True, errors="coerce") + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + return batch + + return data.map_batches(ensure_compatibility, batch_format="pandas") + else: + if not inplace: + data = data.copy() + from feast.utils import make_df_tzaware + + data = make_df_tzaware(data) + for field in timestamp_fields: + if field in data.columns: + data = normalize_timestamp_columns(data, field, inplace=True) + return data + + +def apply_field_mapping( + data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str] +) -> Union[pd.DataFrame, Dataset]: + def rename_columns(df: pd.DataFrame) -> pd.DataFrame: + return df.rename(columns=field_mapping) + + if isinstance(data, Dataset): + return data.map_batches(rename_columns, batch_format="pandas") + else: + return data.rename(columns=field_mapping) + + +def deduplicate_by_keys_and_timestamp( + data: Union[pd.DataFrame, Dataset], + join_keys: List[str], + timestamp_columns: List[str], +) -> Union[pd.DataFrame, Dataset]: + def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame: + if batch.empty: + return batch + sort_columns = join_keys + timestamp_columns + available_columns = [col for col in sort_columns if col in batch.columns] + if available_columns: + sorted_batch = batch.sort_values( + available_columns, + ascending=[True] * len(join_keys) + [False] * len(timestamp_columns), + ) + deduped_batch = sorted_batch.drop_duplicates( + subset=join_keys, + keep="first", + ) + return 
deduped_batch + return batch + + if isinstance(data, Dataset): + return data.map_batches(deduplicate_batch, batch_format="pandas") + else: + return deduplicate_batch(data) + + +def broadcast_join( + entity_ds: Dataset, + feature_df: pd.DataFrame, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, +) -> Dataset: + import ray + + def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: + features = ray.get(feature_ref) + if original_join_keys: + feature_join_keys = original_join_keys + entity_join_keys = join_keys + else: + feature_join_keys = join_keys + entity_join_keys = join_keys + feature_cols = [timestamp_field] + feature_join_keys + requested_feats + available_feature_cols = [ + col for col in feature_cols if col in features.columns + ] + features_filtered = features[available_feature_cols].copy() + + batch = normalize_timestamp_columns(batch, timestamp_field, inplace=True) + features_filtered = normalize_timestamp_columns( + features_filtered, timestamp_field, inplace=True + ) + if not entity_join_keys: + batch_sorted = batch.sort_values(timestamp_field).reset_index(drop=True) + features_sorted = features_filtered.sort_values( + timestamp_field + ).reset_index(drop=True) + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + ) + else: + for key in entity_join_keys: + if key not in batch.columns: + batch[key] = np.nan + for key in feature_join_keys: + if key not in features_filtered.columns: + features_filtered[key] = np.nan + batch_clean = batch.dropna( + subset=entity_join_keys + [timestamp_field] + ).copy() + features_clean = features_filtered.dropna( + subset=feature_join_keys + [timestamp_field] + ).copy() + if batch_clean.empty or features_clean.empty: + return batch.head(0) + if timestamp_field in batch_clean.columns: + batch_sorted = batch_clean.sort_values( + timestamp_field, ascending=True + ).reset_index(drop=True) + else: + batch_sorted = batch_clean.reset_index(drop=True) + right_sort_columns = [ + k for k in feature_join_keys if k in features_clean.columns + ] + if timestamp_field in features_clean.columns: + right_sort_columns.append(timestamp_field) + if right_sort_columns: + features_clean = features_clean.drop_duplicates( + subset=right_sort_columns, keep="last" + ) + features_sorted = features_clean.sort_values( + right_sort_columns, ascending=True + ).reset_index(drop=True) + else: + features_sorted = features_clean.reset_index(drop=True) + try: + if feature_join_keys: + batch_dedup_cols = [ + k for k in entity_join_keys if k in batch_sorted.columns + ] + if timestamp_field in batch_sorted.columns: + batch_dedup_cols.append(timestamp_field) + if batch_dedup_cols: + batch_sorted = batch_sorted.drop_duplicates( + subset=batch_dedup_cols, keep="last" + ) + feature_dedup_cols = [ + k for k in feature_join_keys if k in features_sorted.columns + ] + if timestamp_field in features_sorted.columns: + feature_dedup_cols.append(timestamp_field) + if feature_dedup_cols: + features_sorted = features_sorted.drop_duplicates( + subset=feature_dedup_cols, keep="last" + ) + if feature_join_keys: + if entity_join_keys == feature_join_keys: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + by=entity_join_keys, + direction="backward", + suffixes=("", "_right"), + ) + else: + result = pd.merge_asof( + batch_sorted, + features_sorted, + 
on=timestamp_field, + left_by=entity_join_keys, + right_by=feature_join_keys, + direction="backward", + suffixes=("", "_right"), + ) + else: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + suffixes=("", "_right"), + ) + except Exception: + # fallback to manual join if needed + result = batch_clean # fallback logic can be expanded + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + return result + + feature_ref = ray.put(feature_df) + return entity_ds.map_batches(join_batch_with_features, batch_format="pandas") + + +def distributed_windowed_join( + entity_ds: Dataset, + feature_ds: Dataset, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + window_size: Optional[str] = None, + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, +) -> Dataset: + import pandas as pd + + def add_window_and_source(ds, timestamp_field, source_marker, window_size): + def add_window_and_source_batch(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + if timestamp_field in batch.columns: + batch["time_window"] = ( + pd.to_datetime(batch[timestamp_field]) + .dt.floor(window_size) + .astype("datetime64[ns, UTC]") + ) + batch["_data_source"] = source_marker + return batch + + return ds.map_batches(add_window_and_source_batch, batch_format="pandas") + + entity_windowed = add_window_and_source( + entity_ds, timestamp_field, "entity", window_size or "1H" + ) + feature_windowed = add_window_and_source( + feature_ds, timestamp_field, "feature", window_size or "1H" + ) + combined_ds = entity_windowed.union(feature_windowed) + + def windowed_point_in_time_logic(batch: pd.DataFrame) -> pd.DataFrame: + if len(batch) == 0: + return pd.DataFrame() + result_chunks = [] + group_keys = ["time_window"] + join_keys + for group_values, group_data in batch.groupby(group_keys): + entity_data = group_data[group_data["_data_source"] == "entity"].copy() + feature_data = group_data[group_data["_data_source"] == "feature"].copy() + if len(entity_data) > 0 and len(feature_data) > 0: + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + feature_clean = feature_data.drop( + columns=["time_window", "_data_source"] + ) + if join_keys: + merged = pd.merge_asof( + entity_clean.sort_values(join_keys + [timestamp_field]), + feature_clean.sort_values(join_keys + [timestamp_field]), + on=timestamp_field, + by=join_keys, + direction="backward", + ) + else: + merged = pd.merge_asof( + entity_clean.sort_values(timestamp_field), + feature_clean.sort_values(timestamp_field), + on=timestamp_field, + direction="backward", + ) + result_chunks.append(merged) + elif len(entity_data) > 0: + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + for feat in requested_feats: + if feat not in entity_clean.columns: + entity_clean[feat] = np.nan + result_chunks.append(entity_clean) + if result_chunks: + result = pd.concat(result_chunks, ignore_index=True) + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + return result + else: + return pd.DataFrame() + + return 
combined_ds.map_batches(windowed_point_in_time_logic, batch_format="pandas") + + +def _build_required_columns( + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_columns: List[str], +) -> List[str]: + """ + Build list of required columns for data processing. + Args: + join_key_columns: List of join key columns + feature_name_columns: List of feature columns + timestamp_columns: List of timestamp columns + Returns: + List of all required columns + """ + all_required_columns = join_key_columns + feature_name_columns + timestamp_columns + if not join_key_columns: + all_required_columns.append("__DUMMY_ENTITY_ID__") + if "event_timestamp" not in all_required_columns: + all_required_columns.append("event_timestamp") + return all_required_columns diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index 23ab80ee1d8..ab7944585b6 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -426,19 +426,19 @@ def list_projects( def refresh(self, project: Optional[str] = None): if self._refresh_lock.locked(): - logger.info("Skipping refresh if already in progress") + logger.debug("Skipping refresh if already in progress") return try: self.cached_registry_proto = self.proto() self.cached_registry_proto_created = _utc_now() except Exception as e: - logger.error(f"Error while refreshing registry: {e}", exc_info=True) + logger.debug(f"Error while refreshing registry: {e}", exc_info=True) def _refresh_cached_registry_if_necessary(self): if self.cache_mode == "sync": # Try acquiring the lock without blocking if not self._refresh_lock.acquire(blocking=False): - logger.info( + logger.debug( "Skipping refresh if lock is already held by another thread" ) return @@ -464,10 +464,10 @@ def _refresh_cached_registry_if_necessary(self): ) ) if expired: - logger.info("Registry cache expired, so refreshing") + logger.debug("Registry cache expired, so refreshing") self.refresh() except Exception as e: - logger.error( + logger.debug( f"Error in _refresh_cached_registry_if_necessary: {e}", exc_info=True, ) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 545d5ba4c3a..948410c8861 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -49,6 +49,7 @@ "lambda": "feast.infra.compute_engines.aws_lambda.lambda_engine.LambdaComputeEngine", "k8s": "feast.infra.compute_engines.kubernetes.k8s_engine.KubernetesComputeEngine", "spark.engine": "feast.infra.compute_engines.spark.compute.SparkComputeEngine", + "ray.engine": "feast.infra.compute_engines.ray.compute.RayComputeEngine", } LEGACY_ONLINE_STORE_CLASS_FOR_TYPE = { @@ -101,6 +102,7 @@ "remote": "feast.infra.offline_stores.remote.RemoteOfflineStore", "couchbase.offline": "feast.infra.offline_stores.contrib.couchbase_offline_store.couchbase.CouchbaseColumnarOfflineStore", "clickhouse": "feast.infra.offline_stores.contrib.clickhouse_offline_store.clickhouse.ClickhouseOfflineStore", + "ray": "feast.infra.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore", } FEATURE_SERVER_CONFIG_CLASS_FOR_TYPE = { diff --git a/sdk/python/feast/transformation/pandas_transformation.py b/sdk/python/feast/transformation/pandas_transformation.py index 469ddaa7768..6e073c30100 100644 --- a/sdk/python/feast/transformation/pandas_transformation.py +++ b/sdk/python/feast/transformation/pandas_transformation.py @@ -19,29 +19,43 @@ class PandasTransformation(Transformation): 
def __new__( cls, - udf: Callable[[Any], Any], - udf_string: str, + udf: Optional[Callable[[Any], Any]] = None, + udf_string: Optional[str] = None, name: Optional[str] = None, tags: Optional[dict[str, str]] = None, description: str = "", owner: str = "", ) -> "PandasTransformation": - instance = super(PandasTransformation, cls).__new__( - cls, - mode=TransformationMode.PANDAS, - udf=udf, - name=name, - udf_string=udf_string, - tags=tags, - description=description, - owner=owner, + # Handle Ray deserialization where parameters may not be provided + if udf is None and udf_string is None: + # Create a bare instance for deserialization + instance = object.__new__(cls) + return cast("PandasTransformation", instance) + + # Ensure required parameters are not None before calling parent constructor + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + + return cast( + "PandasTransformation", + super(PandasTransformation, cls).__new__( + cls, + mode=TransformationMode.PANDAS, + udf=udf, + name=name, + udf_string=udf_string, + tags=tags, + description=description, + owner=owner, + ), ) - return cast(PandasTransformation, instance) def __init__( self, - udf: Callable[[Any], Any], - udf_string: str, + udf: Optional[Callable[[Any], Any]] = None, + udf_string: Optional[str] = None, name: Optional[str] = None, tags: Optional[dict[str, str]] = None, description: str = "", @@ -49,6 +63,17 @@ def __init__( *args, **kwargs, ): + # Handle Ray deserialization where parameters may not be provided + if udf is None and udf_string is None: + # Early return for deserialization - don't initialize + return + + # Ensure required parameters are not None before calling parent constructor + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + return_annotation = get_type_hints(udf).get("return", inspect._empty) if return_annotation not in (inspect._empty, pd.DataFrame): raise TypeError( diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index 6781c9a4301..c7f9096d9aa 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -1119,3 +1119,50 @@ def cb_columnar_type_to_feast_value_type(type_str: str) -> ValueType: if value == ValueType.UNKNOWN: print("unknown type:", type_str) return value + + +def convert_scalar_column( + series: pd.Series, value_type: ValueType, target_pandas_type: str +) -> pd.Series: + """Convert a scalar feature column to the appropriate pandas type.""" + if value_type == ValueType.INT32: + return pd.to_numeric(series, errors="coerce").astype("Int32") + elif value_type == ValueType.INT64: + return pd.to_numeric(series, errors="coerce").astype("Int64") + elif value_type in [ValueType.FLOAT, ValueType.DOUBLE]: + return pd.to_numeric(series, errors="coerce").astype("float64") + elif value_type == ValueType.BOOL: + return series.astype("boolean") + elif value_type == ValueType.STRING: + return series.astype("string") + elif value_type == ValueType.UNIX_TIMESTAMP: + return pd.to_datetime(series, unit="s", errors="coerce") + else: + return series.astype(target_pandas_type) + + +def convert_array_column(series: pd.Series, value_type: ValueType) -> pd.Series: + """Convert an array feature column to the appropriate type with proper empty array handling.""" + base_type_map = { + ValueType.INT32_LIST: np.int32, + ValueType.INT64_LIST: np.int64, + ValueType.FLOAT_LIST: 
np.float32, + ValueType.DOUBLE_LIST: np.float64, + ValueType.BOOL_LIST: np.bool_, + ValueType.STRING_LIST: object, + ValueType.BYTES_LIST: object, + ValueType.UNIX_TIMESTAMP_LIST: "datetime64[s]", + } + + target_dtype = base_type_map.get(value_type, object) + + def convert_array_item(item) -> Union[np.ndarray, Any]: + if item is None or (isinstance(item, list) and len(item) == 0): + if target_dtype == object: + return np.empty(0, dtype=object) + else: + return np.empty(0, dtype=target_dtype) # type: ignore + else: + return item + + return series.apply(convert_array_item) diff --git a/sdk/python/requirements/py3.10-ci-requirements.txt b/sdk/python/requirements/py3.10-ci-requirements.txt index 970153b304a..499feed262d 100644 --- a/sdk/python/requirements/py3.10-ci-requirements.txt +++ b/sdk/python/requirements/py3.10-ci-requirements.txt @@ -502,6 +502,7 @@ click==8.2.1 \ # geomet # great-expectations # pip-tools + # ray # typer # uvicorn clickhouse-connect==0.8.18 \ @@ -1062,6 +1063,7 @@ filelock==3.18.0 \ # via # datasets # huggingface-hub + # ray # snowflake-connector-python # torch # transformers @@ -1652,9 +1654,9 @@ httpx-sse==0.4.1 \ --hash=sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e \ --hash=sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37 # via mcp -huggingface-hub==0.34.3 \ - --hash=sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492 \ - --hash=sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853 +huggingface-hub==0.34.4 \ + --hash=sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a \ + --hash=sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c # via # accelerate # datasets @@ -1795,6 +1797,7 @@ jsonschema[format-nongpl]==4.25.0 \ # jupyterlab-server # mcp # nbformat + # ray jsonschema-specifications==2025.4.1 \ --hash=sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af \ --hash=sha256:630159c9f4dbea161a6a2205c3011cc4f18ff381b189fff48bb39b9bf26ae608 @@ -2327,6 +2330,67 @@ msal-extensions==1.3.1 \ --hash=sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca \ --hash=sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4 # via azure-identity +msgpack==1.1.1 \ + --hash=sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8 \ + --hash=sha256:1abfc6e949b352dadf4bce0eb78023212ec5ac42f6abfd469ce91d783c149c2a \ + --hash=sha256:1b13fe0fb4aac1aa5320cd693b297fe6fdef0e7bea5518cbc2dd5299f873ae90 \ + --hash=sha256:1d75f3807a9900a7d575d8d6674a3a47e9f227e8716256f35bc6f03fc597ffbf \ + --hash=sha256:2fbbc0b906a24038c9958a1ba7ae0918ad35b06cb449d398b76a7d08470b0ed9 \ + --hash=sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157 \ + --hash=sha256:353b6fc0c36fde68b661a12949d7d49f8f51ff5fa019c1e47c87c4ff34b080ed \ + --hash=sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d \ + --hash=sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0 \ + --hash=sha256:3a89cd8c087ea67e64844287ea52888239cbd2940884eafd2dcd25754fb72232 \ + --hash=sha256:40eae974c873b2992fd36424a5d9407f93e97656d999f43fca9d29f820899084 \ + --hash=sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5 \ + --hash=sha256:435807eeb1bc791ceb3247d13c79868deb22184e1fc4224808750f0d7d1affc1 \ + --hash=sha256:4835d17af722609a45e16037bb1d4d78b7bdf19d6c0128116d178956618c4e88 \ + --hash=sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752 \ + 
--hash=sha256:4d3237b224b930d58e9d83c81c0dba7aacc20fcc2f89c1e5423aa0529a4cd142 \ + --hash=sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac \ + --hash=sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef \ + --hash=sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323 \ + --hash=sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4 \ + --hash=sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458 \ + --hash=sha256:61abccf9de335d9efd149e2fff97ed5974f2481b3353772e8e2dd3402ba2bd57 \ + --hash=sha256:61e35a55a546a1690d9d09effaa436c25ae6130573b6ee9829c37ef0f18d5e78 \ + --hash=sha256:6640fd979ca9a212e4bcdf6eb74051ade2c690b862b679bfcb60ae46e6dc4bfd \ + --hash=sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69 \ + --hash=sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce \ + --hash=sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558 \ + --hash=sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd \ + --hash=sha256:78426096939c2c7482bf31ef15ca219a9e24460289c00dd0b94411040bb73ad2 \ + --hash=sha256:79c408fcf76a958491b4e3b103d1c417044544b68e96d06432a189b43d1215c8 \ + --hash=sha256:7a17ac1ea6ec3c7687d70201cfda3b1e8061466f28f686c24f627cae4ea8efd0 \ + --hash=sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295 \ + --hash=sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c \ + --hash=sha256:88d1e966c9235c1d4e2afac21ca83933ba59537e2e2727a999bf3f515ca2af26 \ + --hash=sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2 \ + --hash=sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f \ + --hash=sha256:8b17ba27727a36cb73aabacaa44b13090feb88a01d012c0f4be70c00f75048b4 \ + --hash=sha256:8b65b53204fe1bd037c40c4148d00ef918eb2108d24c9aaa20bc31f9810ce0a8 \ + --hash=sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9 \ + --hash=sha256:96decdfc4adcbc087f5ea7ebdcfd3dee9a13358cae6e81d54be962efc38f6338 \ + --hash=sha256:996f2609ddf0142daba4cefd767d6db26958aac8439ee41db9cc0db9f4c4c3a6 \ + --hash=sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a \ + --hash=sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0 \ + --hash=sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a \ + --hash=sha256:a8ef6e342c137888ebbfb233e02b8fbd689bb5b5fcc59b34711ac47ebd504478 \ + --hash=sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238 \ + --hash=sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7 \ + --hash=sha256:b8f93dcddb243159c9e4109c9750ba5b335ab8d48d9522c5308cd05d7e3ce600 \ + --hash=sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704 \ + --hash=sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a \ + --hash=sha256:bba1be28247e68994355e028dcd668316db30c1f758d3241a7b903ac78dcd285 \ + --hash=sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c \ + --hash=sha256:d182dac0221eb8faef2e6f44701812b467c02674a322c739355c39e94730cdbf \ + --hash=sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b \ + --hash=sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2 \ + --hash=sha256:da8f41e602574ece93dbbda1fab24650d6bf2a24089f9e9dbb4f5730ec1e58ad \ + --hash=sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b \ + 
--hash=sha256:f5be6b6bc52fad84d010cb45433720327ce886009d862f46b26d4d154001994b \ + --hash=sha256:f6d58656842e1b2ddbe07f43f56b10a60f2ba5826164910968f5933e5178af75 + # via ray multidict==6.6.3 \ --hash=sha256:02fd8f32d403a6ff13864b0851f1f523d4c988051eea0471d4f1fd8010f11134 \ --hash=sha256:04cbcce84f63b9af41bad04a54d4cc4e60e90c35b9e6ccb130be2d75b71f8c17 \ @@ -2688,6 +2752,7 @@ packaging==24.2 \ # nbconvert # pandas-gbq # pytest + # ray # scikit-image # snowflake-connector-python # sphinx @@ -3090,6 +3155,7 @@ protobuf==4.25.8 \ # proto-plus # pymilvus # qdrant-client + # ray # substrait psutil==5.9.0 \ --hash=sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5 \ @@ -3962,6 +4028,7 @@ pyyaml==6.0.2 \ # jupyter-events # kubernetes # pre-commit + # ray # responses # transformers # uvicorn @@ -4066,6 +4133,32 @@ qdrant-client==1.15.1 \ --hash=sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63 \ --hash=sha256:631f1f3caebfad0fd0c1fba98f41be81d9962b7bf3ca653bed3b727c0e0cbe0e # via feast (setup.py) +ray==2.48.0 \ + --hash=sha256:24a70f416ec0be14b975f160044805ccb48cc6bc50de632983eb8f0a8e16682b \ + --hash=sha256:25e4b79fcc8f849d72db1acc4f03f37008c5c0b745df63d8a30cd35676b6545e \ + --hash=sha256:33bda4753ad0acd2b524c9158089d43486cd44cc59fe970466435bc2968fde2d \ + --hash=sha256:46d4b42a58492dec79caad2d562344689a4f99a828aeea811a0cd2cd653553ef \ + --hash=sha256:4b9b92ac29635f555ef341347d9a63dbf02b7d946347239af3c09e364bc45cf8 \ + --hash=sha256:5742b72a514afe5d60f41330200cd508376e16c650f6962e62337aa482d6a0c6 \ + --hash=sha256:5a6f57126eac9dd3286289e07e91e87b054792f9698b6f7ccab88b624816b542 \ + --hash=sha256:622e6bcdb78d98040d87bea94e65d0bb6ccc0ae1b43294c6bd69f542bf28e092 \ + --hash=sha256:649ed9442dc2d39135c593b6cf0c38e8355170b92672365ab7a3cbc958c42634 \ + --hash=sha256:6ca2b9ce45ad360cbe2996982fb22691ecfe6553ec8f97a2548295f0f96aac78 \ + --hash=sha256:8de799f3b0896f48d306d5e4a04fc6037a08c495d45f9c79935344e5693e3cf8 \ + --hash=sha256:a42ed3b640f4b599a3fc8067c83ee60497c0f03d070d7a7df02a388fa17a546b \ + --hash=sha256:a45de103173c2ed6a0defd7a2919a2bbe531fd5bf6619860cd111ca4a16e9288 \ + --hash=sha256:a7a6d830d9dc5ae8bb156fcde9a1adab7f4edb004f03918a724d885eceb8264d \ + --hash=sha256:b37a0fea4094f95d5926b1d7245abd70deb62882da3d738f9f9b76214894745c \ + --hash=sha256:b427dead5f8ad96d494d3a006d92ea2f8f16be5e6303b590e12234b37f96fbc2 \ + --hash=sha256:b94500fe2d17e491fe2e9bd4a3bf62df217e21a8f2845033c353d4d2ea240f73 \ + --hash=sha256:be45690565907c4aa035d753d82f6ff892d1e6830057b67399542a035b3682f0 \ + --hash=sha256:cfb48c10371c267fdcf7f4ae359cab706f068178b9c65317ead011972f2c0bf3 \ + --hash=sha256:e15fdffa6b60d5729f6025691396b8a01dc3461ba19dc92bba354ec1813ed6b1 \ + --hash=sha256:e6543fb3450a71862cfa1e7a666025d751f81685602cc87d499072ccd839507d \ + --hash=sha256:ea9d7739ae8f6db48b226bbc2a592640f7f2b6d854ff73d0305774b98fa9fb11 \ + --hash=sha256:f1cf33d260316f92f77558185f1c36fc35506d76ee7fdfed9f5b70f9c4bdba7f \ + --hash=sha256:f820950bc44d7b000c223342f5c800c9c08e7fd89524201125388ea211caad1a + # via feast (setup.py) redis==4.6.0 \ --hash=sha256:585dc516b9eb042a619ef0a39c3d7d55fe81bdb4df09a52c9cdde0d07bf1aa7d \ --hash=sha256:e2b03db868160ee4591de3cb90d40ebb50a90dd302138775937f6a42b7ed183c @@ -4189,6 +4282,7 @@ requests==2.32.4 \ # moto # msal # python-keycloak + # ray # requests-oauthlib # requests-toolbelt # responses diff --git a/sdk/python/requirements/py3.11-ci-requirements.txt b/sdk/python/requirements/py3.11-ci-requirements.txt index a44b6551fbf..a14ea9657a9 100644 --- 
a/sdk/python/requirements/py3.11-ci-requirements.txt +++ b/sdk/python/requirements/py3.11-ci-requirements.txt @@ -500,6 +500,7 @@ click==8.2.1 \ # geomet # great-expectations # pip-tools + # ray # typer # uvicorn clickhouse-connect==0.8.18 \ @@ -1053,6 +1054,7 @@ filelock==3.18.0 \ # via # datasets # huggingface-hub + # ray # snowflake-connector-python # torch # transformers @@ -1643,9 +1645,9 @@ httpx-sse==0.4.1 \ --hash=sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e \ --hash=sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37 # via mcp -huggingface-hub==0.34.3 \ - --hash=sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492 \ - --hash=sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853 +huggingface-hub==0.34.4 \ + --hash=sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a \ + --hash=sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c # via # accelerate # datasets @@ -1788,6 +1790,7 @@ jsonschema[format-nongpl]==4.25.0 \ # jupyterlab-server # mcp # nbformat + # ray jsonschema-specifications==2025.4.1 \ --hash=sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af \ --hash=sha256:630159c9f4dbea161a6a2205c3011cc4f18ff381b189fff48bb39b9bf26ae608 @@ -2320,6 +2323,67 @@ msal-extensions==1.3.1 \ --hash=sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca \ --hash=sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4 # via azure-identity +msgpack==1.1.1 \ + --hash=sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8 \ + --hash=sha256:1abfc6e949b352dadf4bce0eb78023212ec5ac42f6abfd469ce91d783c149c2a \ + --hash=sha256:1b13fe0fb4aac1aa5320cd693b297fe6fdef0e7bea5518cbc2dd5299f873ae90 \ + --hash=sha256:1d75f3807a9900a7d575d8d6674a3a47e9f227e8716256f35bc6f03fc597ffbf \ + --hash=sha256:2fbbc0b906a24038c9958a1ba7ae0918ad35b06cb449d398b76a7d08470b0ed9 \ + --hash=sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157 \ + --hash=sha256:353b6fc0c36fde68b661a12949d7d49f8f51ff5fa019c1e47c87c4ff34b080ed \ + --hash=sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d \ + --hash=sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0 \ + --hash=sha256:3a89cd8c087ea67e64844287ea52888239cbd2940884eafd2dcd25754fb72232 \ + --hash=sha256:40eae974c873b2992fd36424a5d9407f93e97656d999f43fca9d29f820899084 \ + --hash=sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5 \ + --hash=sha256:435807eeb1bc791ceb3247d13c79868deb22184e1fc4224808750f0d7d1affc1 \ + --hash=sha256:4835d17af722609a45e16037bb1d4d78b7bdf19d6c0128116d178956618c4e88 \ + --hash=sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752 \ + --hash=sha256:4d3237b224b930d58e9d83c81c0dba7aacc20fcc2f89c1e5423aa0529a4cd142 \ + --hash=sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac \ + --hash=sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef \ + --hash=sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323 \ + --hash=sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4 \ + --hash=sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458 \ + --hash=sha256:61abccf9de335d9efd149e2fff97ed5974f2481b3353772e8e2dd3402ba2bd57 \ + --hash=sha256:61e35a55a546a1690d9d09effaa436c25ae6130573b6ee9829c37ef0f18d5e78 \ + --hash=sha256:6640fd979ca9a212e4bcdf6eb74051ade2c690b862b679bfcb60ae46e6dc4bfd \ + 
--hash=sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69 \ + --hash=sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce \ + --hash=sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558 \ + --hash=sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd \ + --hash=sha256:78426096939c2c7482bf31ef15ca219a9e24460289c00dd0b94411040bb73ad2 \ + --hash=sha256:79c408fcf76a958491b4e3b103d1c417044544b68e96d06432a189b43d1215c8 \ + --hash=sha256:7a17ac1ea6ec3c7687d70201cfda3b1e8061466f28f686c24f627cae4ea8efd0 \ + --hash=sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295 \ + --hash=sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c \ + --hash=sha256:88d1e966c9235c1d4e2afac21ca83933ba59537e2e2727a999bf3f515ca2af26 \ + --hash=sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2 \ + --hash=sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f \ + --hash=sha256:8b17ba27727a36cb73aabacaa44b13090feb88a01d012c0f4be70c00f75048b4 \ + --hash=sha256:8b65b53204fe1bd037c40c4148d00ef918eb2108d24c9aaa20bc31f9810ce0a8 \ + --hash=sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9 \ + --hash=sha256:96decdfc4adcbc087f5ea7ebdcfd3dee9a13358cae6e81d54be962efc38f6338 \ + --hash=sha256:996f2609ddf0142daba4cefd767d6db26958aac8439ee41db9cc0db9f4c4c3a6 \ + --hash=sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a \ + --hash=sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0 \ + --hash=sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a \ + --hash=sha256:a8ef6e342c137888ebbfb233e02b8fbd689bb5b5fcc59b34711ac47ebd504478 \ + --hash=sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238 \ + --hash=sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7 \ + --hash=sha256:b8f93dcddb243159c9e4109c9750ba5b335ab8d48d9522c5308cd05d7e3ce600 \ + --hash=sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704 \ + --hash=sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a \ + --hash=sha256:bba1be28247e68994355e028dcd668316db30c1f758d3241a7b903ac78dcd285 \ + --hash=sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c \ + --hash=sha256:d182dac0221eb8faef2e6f44701812b467c02674a322c739355c39e94730cdbf \ + --hash=sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b \ + --hash=sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2 \ + --hash=sha256:da8f41e602574ece93dbbda1fab24650d6bf2a24089f9e9dbb4f5730ec1e58ad \ + --hash=sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b \ + --hash=sha256:f5be6b6bc52fad84d010cb45433720327ce886009d862f46b26d4d154001994b \ + --hash=sha256:f6d58656842e1b2ddbe07f43f56b10a60f2ba5826164910968f5933e5178af75 + # via ray multidict==6.6.3 \ --hash=sha256:02fd8f32d403a6ff13864b0851f1f523d4c988051eea0471d4f1fd8010f11134 \ --hash=sha256:04cbcce84f63b9af41bad04a54d4cc4e60e90c35b9e6ccb130be2d75b71f8c17 \ @@ -2700,6 +2764,7 @@ packaging==24.2 \ # nbconvert # pandas-gbq # pytest + # ray # scikit-image # snowflake-connector-python # sphinx @@ -3102,6 +3167,7 @@ protobuf==4.25.8 \ # proto-plus # pymilvus # qdrant-client + # ray # substrait psutil==5.9.0 \ --hash=sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5 \ @@ -3975,6 +4041,7 @@ pyyaml==6.0.2 \ # jupyter-events # kubernetes # pre-commit + # ray # responses # transformers # uvicorn 
@@ -4079,6 +4146,32 @@ qdrant-client==1.15.1 \ --hash=sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63 \ --hash=sha256:631f1f3caebfad0fd0c1fba98f41be81d9962b7bf3ca653bed3b727c0e0cbe0e # via feast (setup.py) +ray==2.48.0 \ + --hash=sha256:24a70f416ec0be14b975f160044805ccb48cc6bc50de632983eb8f0a8e16682b \ + --hash=sha256:25e4b79fcc8f849d72db1acc4f03f37008c5c0b745df63d8a30cd35676b6545e \ + --hash=sha256:33bda4753ad0acd2b524c9158089d43486cd44cc59fe970466435bc2968fde2d \ + --hash=sha256:46d4b42a58492dec79caad2d562344689a4f99a828aeea811a0cd2cd653553ef \ + --hash=sha256:4b9b92ac29635f555ef341347d9a63dbf02b7d946347239af3c09e364bc45cf8 \ + --hash=sha256:5742b72a514afe5d60f41330200cd508376e16c650f6962e62337aa482d6a0c6 \ + --hash=sha256:5a6f57126eac9dd3286289e07e91e87b054792f9698b6f7ccab88b624816b542 \ + --hash=sha256:622e6bcdb78d98040d87bea94e65d0bb6ccc0ae1b43294c6bd69f542bf28e092 \ + --hash=sha256:649ed9442dc2d39135c593b6cf0c38e8355170b92672365ab7a3cbc958c42634 \ + --hash=sha256:6ca2b9ce45ad360cbe2996982fb22691ecfe6553ec8f97a2548295f0f96aac78 \ + --hash=sha256:8de799f3b0896f48d306d5e4a04fc6037a08c495d45f9c79935344e5693e3cf8 \ + --hash=sha256:a42ed3b640f4b599a3fc8067c83ee60497c0f03d070d7a7df02a388fa17a546b \ + --hash=sha256:a45de103173c2ed6a0defd7a2919a2bbe531fd5bf6619860cd111ca4a16e9288 \ + --hash=sha256:a7a6d830d9dc5ae8bb156fcde9a1adab7f4edb004f03918a724d885eceb8264d \ + --hash=sha256:b37a0fea4094f95d5926b1d7245abd70deb62882da3d738f9f9b76214894745c \ + --hash=sha256:b427dead5f8ad96d494d3a006d92ea2f8f16be5e6303b590e12234b37f96fbc2 \ + --hash=sha256:b94500fe2d17e491fe2e9bd4a3bf62df217e21a8f2845033c353d4d2ea240f73 \ + --hash=sha256:be45690565907c4aa035d753d82f6ff892d1e6830057b67399542a035b3682f0 \ + --hash=sha256:cfb48c10371c267fdcf7f4ae359cab706f068178b9c65317ead011972f2c0bf3 \ + --hash=sha256:e15fdffa6b60d5729f6025691396b8a01dc3461ba19dc92bba354ec1813ed6b1 \ + --hash=sha256:e6543fb3450a71862cfa1e7a666025d751f81685602cc87d499072ccd839507d \ + --hash=sha256:ea9d7739ae8f6db48b226bbc2a592640f7f2b6d854ff73d0305774b98fa9fb11 \ + --hash=sha256:f1cf33d260316f92f77558185f1c36fc35506d76ee7fdfed9f5b70f9c4bdba7f \ + --hash=sha256:f820950bc44d7b000c223342f5c800c9c08e7fd89524201125388ea211caad1a + # via feast (setup.py) redis==4.6.0 \ --hash=sha256:585dc516b9eb042a619ef0a39c3d7d55fe81bdb4df09a52c9cdde0d07bf1aa7d \ --hash=sha256:e2b03db868160ee4591de3cb90d40ebb50a90dd302138775937f6a42b7ed183c @@ -4202,6 +4295,7 @@ requests==2.32.4 \ # moto # msal # python-keycloak + # ray # requests-oauthlib # requests-toolbelt # responses diff --git a/sdk/python/requirements/py3.12-ci-requirements.txt b/sdk/python/requirements/py3.12-ci-requirements.txt index c1dcdeadf07..5f6cb40390f 100644 --- a/sdk/python/requirements/py3.12-ci-requirements.txt +++ b/sdk/python/requirements/py3.12-ci-requirements.txt @@ -496,6 +496,7 @@ click==8.2.1 \ # geomet # great-expectations # pip-tools + # ray # typer # uvicorn clickhouse-connect==0.8.18 \ @@ -1049,6 +1050,7 @@ filelock==3.18.0 \ # via # datasets # huggingface-hub + # ray # snowflake-connector-python # torch # transformers @@ -1639,9 +1641,9 @@ httpx-sse==0.4.1 \ --hash=sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e \ --hash=sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37 # via mcp -huggingface-hub==0.34.3 \ - --hash=sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492 \ - --hash=sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853 +huggingface-hub==0.34.4 \ + 
--hash=sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a \ + --hash=sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c # via # accelerate # datasets @@ -1780,6 +1782,7 @@ jsonschema[format-nongpl]==4.25.0 \ # jupyterlab-server # mcp # nbformat + # ray jsonschema-specifications==2025.4.1 \ --hash=sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af \ --hash=sha256:630159c9f4dbea161a6a2205c3011cc4f18ff381b189fff48bb39b9bf26ae608 @@ -2312,6 +2315,67 @@ msal-extensions==1.3.1 \ --hash=sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca \ --hash=sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4 # via azure-identity +msgpack==1.1.1 \ + --hash=sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8 \ + --hash=sha256:1abfc6e949b352dadf4bce0eb78023212ec5ac42f6abfd469ce91d783c149c2a \ + --hash=sha256:1b13fe0fb4aac1aa5320cd693b297fe6fdef0e7bea5518cbc2dd5299f873ae90 \ + --hash=sha256:1d75f3807a9900a7d575d8d6674a3a47e9f227e8716256f35bc6f03fc597ffbf \ + --hash=sha256:2fbbc0b906a24038c9958a1ba7ae0918ad35b06cb449d398b76a7d08470b0ed9 \ + --hash=sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157 \ + --hash=sha256:353b6fc0c36fde68b661a12949d7d49f8f51ff5fa019c1e47c87c4ff34b080ed \ + --hash=sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d \ + --hash=sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0 \ + --hash=sha256:3a89cd8c087ea67e64844287ea52888239cbd2940884eafd2dcd25754fb72232 \ + --hash=sha256:40eae974c873b2992fd36424a5d9407f93e97656d999f43fca9d29f820899084 \ + --hash=sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5 \ + --hash=sha256:435807eeb1bc791ceb3247d13c79868deb22184e1fc4224808750f0d7d1affc1 \ + --hash=sha256:4835d17af722609a45e16037bb1d4d78b7bdf19d6c0128116d178956618c4e88 \ + --hash=sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752 \ + --hash=sha256:4d3237b224b930d58e9d83c81c0dba7aacc20fcc2f89c1e5423aa0529a4cd142 \ + --hash=sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac \ + --hash=sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef \ + --hash=sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323 \ + --hash=sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4 \ + --hash=sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458 \ + --hash=sha256:61abccf9de335d9efd149e2fff97ed5974f2481b3353772e8e2dd3402ba2bd57 \ + --hash=sha256:61e35a55a546a1690d9d09effaa436c25ae6130573b6ee9829c37ef0f18d5e78 \ + --hash=sha256:6640fd979ca9a212e4bcdf6eb74051ade2c690b862b679bfcb60ae46e6dc4bfd \ + --hash=sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69 \ + --hash=sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce \ + --hash=sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558 \ + --hash=sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd \ + --hash=sha256:78426096939c2c7482bf31ef15ca219a9e24460289c00dd0b94411040bb73ad2 \ + --hash=sha256:79c408fcf76a958491b4e3b103d1c417044544b68e96d06432a189b43d1215c8 \ + --hash=sha256:7a17ac1ea6ec3c7687d70201cfda3b1e8061466f28f686c24f627cae4ea8efd0 \ + --hash=sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295 \ + --hash=sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c \ + 
--hash=sha256:88d1e966c9235c1d4e2afac21ca83933ba59537e2e2727a999bf3f515ca2af26 \ + --hash=sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2 \ + --hash=sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f \ + --hash=sha256:8b17ba27727a36cb73aabacaa44b13090feb88a01d012c0f4be70c00f75048b4 \ + --hash=sha256:8b65b53204fe1bd037c40c4148d00ef918eb2108d24c9aaa20bc31f9810ce0a8 \ + --hash=sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9 \ + --hash=sha256:96decdfc4adcbc087f5ea7ebdcfd3dee9a13358cae6e81d54be962efc38f6338 \ + --hash=sha256:996f2609ddf0142daba4cefd767d6db26958aac8439ee41db9cc0db9f4c4c3a6 \ + --hash=sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a \ + --hash=sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0 \ + --hash=sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a \ + --hash=sha256:a8ef6e342c137888ebbfb233e02b8fbd689bb5b5fcc59b34711ac47ebd504478 \ + --hash=sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238 \ + --hash=sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7 \ + --hash=sha256:b8f93dcddb243159c9e4109c9750ba5b335ab8d48d9522c5308cd05d7e3ce600 \ + --hash=sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704 \ + --hash=sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a \ + --hash=sha256:bba1be28247e68994355e028dcd668316db30c1f758d3241a7b903ac78dcd285 \ + --hash=sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c \ + --hash=sha256:d182dac0221eb8faef2e6f44701812b467c02674a322c739355c39e94730cdbf \ + --hash=sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b \ + --hash=sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2 \ + --hash=sha256:da8f41e602574ece93dbbda1fab24650d6bf2a24089f9e9dbb4f5730ec1e58ad \ + --hash=sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b \ + --hash=sha256:f5be6b6bc52fad84d010cb45433720327ce886009d862f46b26d4d154001994b \ + --hash=sha256:f6d58656842e1b2ddbe07f43f56b10a60f2ba5826164910968f5933e5178af75 + # via ray multidict==6.6.3 \ --hash=sha256:02fd8f32d403a6ff13864b0851f1f523d4c988051eea0471d4f1fd8010f11134 \ --hash=sha256:04cbcce84f63b9af41bad04a54d4cc4e60e90c35b9e6ccb130be2d75b71f8c17 \ @@ -2692,6 +2756,7 @@ packaging==24.2 \ # nbconvert # pandas-gbq # pytest + # ray # scikit-image # snowflake-connector-python # sphinx @@ -3094,6 +3159,7 @@ protobuf==4.25.8 \ # proto-plus # pymilvus # qdrant-client + # ray # substrait psutil==5.9.0 \ --hash=sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5 \ @@ -3967,6 +4033,7 @@ pyyaml==6.0.2 \ # jupyter-events # kubernetes # pre-commit + # ray # responses # transformers # uvicorn @@ -4071,6 +4138,32 @@ qdrant-client==1.15.1 \ --hash=sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63 \ --hash=sha256:631f1f3caebfad0fd0c1fba98f41be81d9962b7bf3ca653bed3b727c0e0cbe0e # via feast (setup.py) +ray==2.48.0 \ + --hash=sha256:24a70f416ec0be14b975f160044805ccb48cc6bc50de632983eb8f0a8e16682b \ + --hash=sha256:25e4b79fcc8f849d72db1acc4f03f37008c5c0b745df63d8a30cd35676b6545e \ + --hash=sha256:33bda4753ad0acd2b524c9158089d43486cd44cc59fe970466435bc2968fde2d \ + --hash=sha256:46d4b42a58492dec79caad2d562344689a4f99a828aeea811a0cd2cd653553ef \ + --hash=sha256:4b9b92ac29635f555ef341347d9a63dbf02b7d946347239af3c09e364bc45cf8 \ + --hash=sha256:5742b72a514afe5d60f41330200cd508376e16c650f6962e62337aa482d6a0c6 \ + 
--hash=sha256:5a6f57126eac9dd3286289e07e91e87b054792f9698b6f7ccab88b624816b542 \ + --hash=sha256:622e6bcdb78d98040d87bea94e65d0bb6ccc0ae1b43294c6bd69f542bf28e092 \ + --hash=sha256:649ed9442dc2d39135c593b6cf0c38e8355170b92672365ab7a3cbc958c42634 \ + --hash=sha256:6ca2b9ce45ad360cbe2996982fb22691ecfe6553ec8f97a2548295f0f96aac78 \ + --hash=sha256:8de799f3b0896f48d306d5e4a04fc6037a08c495d45f9c79935344e5693e3cf8 \ + --hash=sha256:a42ed3b640f4b599a3fc8067c83ee60497c0f03d070d7a7df02a388fa17a546b \ + --hash=sha256:a45de103173c2ed6a0defd7a2919a2bbe531fd5bf6619860cd111ca4a16e9288 \ + --hash=sha256:a7a6d830d9dc5ae8bb156fcde9a1adab7f4edb004f03918a724d885eceb8264d \ + --hash=sha256:b37a0fea4094f95d5926b1d7245abd70deb62882da3d738f9f9b76214894745c \ + --hash=sha256:b427dead5f8ad96d494d3a006d92ea2f8f16be5e6303b590e12234b37f96fbc2 \ + --hash=sha256:b94500fe2d17e491fe2e9bd4a3bf62df217e21a8f2845033c353d4d2ea240f73 \ + --hash=sha256:be45690565907c4aa035d753d82f6ff892d1e6830057b67399542a035b3682f0 \ + --hash=sha256:cfb48c10371c267fdcf7f4ae359cab706f068178b9c65317ead011972f2c0bf3 \ + --hash=sha256:e15fdffa6b60d5729f6025691396b8a01dc3461ba19dc92bba354ec1813ed6b1 \ + --hash=sha256:e6543fb3450a71862cfa1e7a666025d751f81685602cc87d499072ccd839507d \ + --hash=sha256:ea9d7739ae8f6db48b226bbc2a592640f7f2b6d854ff73d0305774b98fa9fb11 \ + --hash=sha256:f1cf33d260316f92f77558185f1c36fc35506d76ee7fdfed9f5b70f9c4bdba7f \ + --hash=sha256:f820950bc44d7b000c223342f5c800c9c08e7fd89524201125388ea211caad1a + # via feast (setup.py) redis==4.6.0 \ --hash=sha256:585dc516b9eb042a619ef0a39c3d7d55fe81bdb4df09a52c9cdde0d07bf1aa7d \ --hash=sha256:e2b03db868160ee4591de3cb90d40ebb50a90dd302138775937f6a42b7ed183c @@ -4194,6 +4287,7 @@ requests==2.32.4 \ # moto # msal # python-keycloak + # ray # requests-oauthlib # requests-toolbelt # responses diff --git a/sdk/python/tests/doctest/test_all.py b/sdk/python/tests/doctest/test_all.py index de032264e6d..8a85a72ab45 100644 --- a/sdk/python/tests/doctest/test_all.py +++ b/sdk/python/tests/doctest/test_all.py @@ -71,48 +71,53 @@ def test_docstrings(): next_packages = [] for package in current_packages: - for _, name, is_pkg in pkgutil.walk_packages(package.__path__): - if name in FILES_TO_IGNORE: - continue - - full_name = package.__name__ + "." + name - try: - # https://github.com/feast-dev/feast/issues/5088 - if "ikv" not in full_name and "milvus" not in full_name: - temp_module = importlib.import_module(full_name) - if is_pkg: - next_packages.append(temp_module) - except ModuleNotFoundError: - pass - - # Retrieve the setup and teardown functions defined in this file. - relative_path_from_feast = full_name.split(".", 1)[1] - function_suffix = relative_path_from_feast.replace(".", "_") - setup_function_name = "setup_" + function_suffix - teardown_function_name = "teardown_" + function_suffix - setup_function = globals().get(setup_function_name) - teardown_function = globals().get(teardown_function_name) - - # Execute the test with setup and teardown functions. 
- try: - if setup_function: - setup_function() - - test_suite = doctest.DocTestSuite( - temp_module, - optionflags=doctest.ELLIPSIS, - ) - if test_suite.countTestCases() > 0: - result = unittest.TextTestRunner(sys.stdout).run(test_suite) - if not result.wasSuccessful(): - successful = False - failed_cases.append(result.failures) - except Exception as e: - successful = False - failed_cases.append((full_name, str(e) + traceback.format_exc())) - finally: - if teardown_function: - teardown_function() + try: + for _, name, is_pkg in pkgutil.walk_packages(package.__path__): + if name in FILES_TO_IGNORE: + continue + + full_name = package.__name__ + "." + name + try: + # https://github.com/feast-dev/feast/issues/5088 + if "ikv" not in full_name and "milvus" not in full_name: + temp_module = importlib.import_module(full_name) + if is_pkg: + next_packages.append(temp_module) + except ModuleNotFoundError: + pass + + # Retrieve the setup and teardown functions defined in this file. + relative_path_from_feast = full_name.split(".", 1)[1] + function_suffix = relative_path_from_feast.replace(".", "_") + setup_function_name = "setup_" + function_suffix + teardown_function_name = "teardown_" + function_suffix + setup_function = globals().get(setup_function_name) + teardown_function = globals().get(teardown_function_name) + + # Execute the test with setup and teardown functions. + try: + if setup_function: + setup_function() + + test_suite = doctest.DocTestSuite( + temp_module, + optionflags=doctest.ELLIPSIS, + ) + if test_suite.countTestCases() > 0: + result = unittest.TextTestRunner(sys.stdout).run(test_suite) + if not result.wasSuccessful(): + successful = False + failed_cases.append(result.failures) + except Exception as e: + successful = False + failed_cases.append( + (full_name, str(e) + traceback.format_exc()) + ) + finally: + if teardown_function: + teardown_function() + except DeprecationWarning: # To catch ray.tune.automl deprecation + pass current_packages = next_packages diff --git a/sdk/python/tests/integration/__init__.py b/sdk/python/tests/integration/__init__.py new file mode 100644 index 00000000000..c66cd71b7e1 --- /dev/null +++ b/sdk/python/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests package.""" diff --git a/sdk/python/tests/integration/compute_engines/__init__.py b/sdk/python/tests/integration/compute_engines/__init__.py new file mode 100644 index 00000000000..6a582448b68 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/__init__.py @@ -0,0 +1 @@ +"""Compute engines integration tests package.""" diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/__init__.py b/sdk/python/tests/integration/compute_engines/ray_compute/__init__.py new file mode 100644 index 00000000000..7938db59420 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/__init__.py @@ -0,0 +1 @@ +"""Ray compute engine integration tests.""" diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/conftest.py b/sdk/python/tests/integration/compute_engines/ray_compute/conftest.py new file mode 100644 index 00000000000..885b1555ec7 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/conftest.py @@ -0,0 +1,26 @@ +"""Pytest configuration and fixtures for Ray compute engine tests. + +This module exposes fixtures from ray_shared_utils.py so they can be +auto-discovered by pytest. 
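+
+A ``ray`` marker is also registered in ``pytest_configure`` so tests that
+require the Ray compute engine can be tagged explicitly.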
+""" + +from tests.integration.compute_engines.ray_compute.ray_shared_utils import ( + entity_df, + feature_dataset, + ray_environment, + temp_dir, +) + + +def pytest_configure(config): + """Configure pytest for Ray tests.""" + config.addinivalue_line("markers", "ray: mark test as requiring Ray compute engine") + + +__all__ = [ + "entity_df", + "feature_dataset", + "ray_environment", + "temp_dir", + "pytest_configure", +] diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py b/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py new file mode 100644 index 00000000000..9e9aabc4f90 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py @@ -0,0 +1,177 @@ +"""Shared fixtures and utilities for Ray compute engine tests.""" + +import os +import tempfile +import time +import uuid +from datetime import timedelta +from typing import Generator + +import pandas as pd +import pytest +import ray + +from feast import Entity, FileSource +from feast.data_source import DataSource +from feast.utils import _utc_now +from tests.integration.feature_repos.repo_configuration import ( + construct_test_environment, +) + +from .repo_configuration import get_ray_compute_engine_test_config + +now = _utc_now().replace(microsecond=0, second=0, minute=0) +today = now.replace(hour=0, minute=0, second=0, microsecond=0) + + +def get_test_date_range(days_back: int = 7) -> tuple: + """Get a standard test date range (start_date, end_date) for testing.""" + end_date = now + start_date = now - timedelta(days=days_back) + return start_date, end_date + + +driver = Entity( + name="driver_id", + description="driver id", +) + + +def create_feature_dataset(ray_environment) -> DataSource: + """Create a test dataset for feature views.""" + yesterday = today - timedelta(days=1) + last_week = today - timedelta(days=7) + df = pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.5, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": last_week, + "created": now - timedelta(hours=3), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.7, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.3, + "acc_rate": 0.6, + "avg_daily_trips": 12, + }, + ] + ) + ds = ray_environment.data_source_creator.create_data_source( + df, + ray_environment.feature_store.project, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + return ds + + +def create_entity_df() -> pd.DataFrame: + """Create entity dataframe for testing.""" + entity_df = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": today}, + {"driver_id": 1002, "event_timestamp": today}, + ] + ) + return entity_df + + +def create_unique_sink_source(temp_dir: str, base_name: str) -> FileSource: + """Create a unique sink source to avoid path collisions during parallel test execution.""" + timestamp = int(time.time() * 1000) + process_id = os.getpid() + unique_id = str(uuid.uuid4())[:8] + + # Create a unique directory for this sink - Ray needs directory paths for materialization + sink_dir = os.path.join( + temp_dir, f"{base_name}_{timestamp}_{process_id}_{unique_id}" + ) + os.makedirs(sink_dir, 
exist_ok=True) + + return FileSource( + name=f"{base_name}_sink_source", + path=sink_dir, # Use directory path - Ray will create files inside + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + + +def cleanup_ray_environment(ray_environment): + """Safely cleanup Ray environment and resources.""" + try: + ray_environment.teardown() + except Exception as e: + print(f"Warning: Ray environment teardown failed: {e}") + + # Ensure Ray is shut down completely + try: + if ray.is_initialized(): + ray.shutdown() + time.sleep(0.2) # Brief pause to ensure clean shutdown + except Exception as e: + print(f"Warning: Ray shutdown failed: {e}") + + +def create_ray_environment(): + """Create Ray test environment using the standardized config.""" + ray_config = get_ray_compute_engine_test_config() + ray_environment = construct_test_environment( + ray_config, None, entity_key_serialization_version=3 + ) + ray_environment.setup() + return ray_environment + + +@pytest.fixture(scope="function") +def ray_environment() -> Generator: + """Pytest fixture to provide a Ray environment for tests with automatic cleanup.""" + try: + if ray.is_initialized(): + ray.shutdown() + time.sleep(0.2) + except Exception: + pass + + environment = create_ray_environment() + yield environment + cleanup_ray_environment(environment) + + +@pytest.fixture +def feature_dataset(ray_environment) -> DataSource: + """Fixture that provides a feature dataset for testing.""" + return create_feature_dataset(ray_environment) + + +@pytest.fixture +def entity_df() -> pd.DataFrame: + """Fixture that provides an entity dataframe for testing.""" + return create_entity_df() + + +@pytest.fixture +def temp_dir() -> Generator[str, None, None]: + """Fixture that provides a temporary directory for test artifacts.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py new file mode 100644 index 00000000000..37d0d020ccd --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py @@ -0,0 +1,25 @@ +"""Test configuration for Ray compute engine integration tests.""" + +from feast.infra.offline_stores.contrib.ray_repo_configuration import ( + RayDataSourceCreator, +) +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.online_store.redis import ( + RedisOnlineStoreCreator, +) + + +def get_ray_compute_engine_test_config() -> IntegrationTestRepoConfig: + """Get test configuration for Ray compute engine.""" + return IntegrationTestRepoConfig( + provider="local", + online_store_creator=RedisOnlineStoreCreator, + offline_store_creator=RayDataSourceCreator, + batch_engine={ + "type": "ray.engine", + "max_workers": 1, + "enable_optimization": True, + }, + ) diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py new file mode 100644 index 00000000000..e7060b4a756 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py @@ -0,0 +1,171 @@ +from datetime import timedelta +from typing import cast +from unittest.mock import MagicMock + +import pandas as pd +import pytest +from tqdm import tqdm + +from feast import BatchFeatureView, Field +from feast.aggregation import Aggregation +from 
feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.ray.compute import RayComputeEngine +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.ray.job import RayDAGRetrievalJob +from feast.infra.offline_stores.contrib.ray_offline_store.ray import ( + RayOfflineStore, +) +from feast.types import Float32, Int32, Int64 +from tests.integration.compute_engines.ray_compute.ray_shared_utils import ( + driver, + now, +) + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_get_historical_features( + ray_environment, feature_dataset, entity_df +): + """Test Ray compute engine historical feature retrieval.""" + fs = ray_environment.feature_store + registry = fs.registry + + def transform_feature(df: pd.DataFrame) -> pd.DataFrame: + df["sum_conv_rate"] = df["sum_conv_rate"] * 2 + df["avg_acc_rate"] = df["avg_acc_rate"] * 2 + return df + + driver_stats_fv = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="pandas", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + udf=transform_feature, + udf_string="transform_feature", + ttl=timedelta(days=3), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=False, + offline=False, + source=feature_dataset, + ) + + fs.apply([driver, driver_stats_fv]) + + # Build retrieval task + task = HistoricalRetrievalTask( + project=ray_environment.project, + entity_df=entity_df, + feature_view=driver_stats_fv, + full_feature_name=False, + registry=registry, + ) + engine = RayComputeEngine( + repo_config=ray_environment.config, + offline_store=RayOfflineStore(), + online_store=MagicMock(), + ) + + ray_dag_retrieval_job = engine.get_historical_features(registry, task) + ray_dataset = cast(RayDAGRetrievalJob, ray_dag_retrieval_job).to_ray_dataset() + df_out = ray_dataset.to_pandas().sort_values("driver_id") + + assert df_out.driver_id.to_list() == [1001, 1002] + assert abs(df_out["sum_conv_rate"].to_list()[0] - 1.6) < 1e-6 + assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[1] - 1.0) < 1e-6 + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_materialize(ray_environment, feature_dataset): + """Test Ray compute engine materialization.""" + fs = ray_environment.feature_store + registry = fs.registry + + def transform_feature(df: pd.DataFrame) -> pd.DataFrame: + df["sum_conv_rate"] = df["sum_conv_rate"] * 2 + df["avg_acc_rate"] = df["avg_acc_rate"] * 2 + return df + + driver_stats_fv = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="pandas", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + udf=transform_feature, + udf_string="transform_feature", + ttl=timedelta(days=3), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=False, + source=feature_dataset, + ) + + def 
tqdm_builder(length): + return tqdm(length, ncols=100) + + fs.apply([driver, driver_stats_fv]) + + task = MaterializationTask( + project=ray_environment.project, + feature_view=driver_stats_fv, + start_time=now - timedelta(days=2), + end_time=now, + tqdm_builder=tqdm_builder, + ) + + engine = RayComputeEngine( + repo_config=ray_environment.config, + offline_store=RayOfflineStore(), + online_store=MagicMock(), + ) + + ray_materialize_jobs = engine.materialize(registry, task) + + assert len(ray_materialize_jobs) == 1 + assert ray_materialize_jobs[0].status() == MaterializationJobStatus.SUCCEEDED + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_config(): + """Test Ray compute engine configuration.""" + config = RayComputeEngineConfig( + type="ray.engine", + ray_address="ray://localhost:10001", + broadcast_join_threshold_mb=200, + enable_distributed_joins=True, + max_parallelism_multiplier=4, + target_partition_size_mb=128, + window_size_for_joins="2H", + max_workers=4, + enable_optimization=True, + execution_timeout_seconds=3600, + ) + + assert config.type == "ray.engine" + assert config.ray_address == "ray://localhost:10001" + assert config.broadcast_join_threshold_mb == 200 + assert config.window_size_timedelta == timedelta(hours=2) diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/test_source_feature_views.py b/sdk/python/tests/integration/compute_engines/ray_compute/test_source_feature_views.py new file mode 100644 index 00000000000..7d8f23e1bf6 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/test_source_feature_views.py @@ -0,0 +1,308 @@ +import time +from datetime import timedelta + +import pandas as pd +import pytest + +from feast import FeatureView, Field +from feast.data_source import DataSource +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, +) +from feast.types import Float32, Int32, Int64 +from tests.integration.compute_engines.ray_compute.ray_shared_utils import ( + create_entity_df, + create_feature_dataset, + create_unique_sink_source, + driver, + now, + today, +) + + +def create_base_feature_view(source: DataSource, name_suffix: str = "") -> FeatureView: + """Create a base feature view with data source.""" + return FeatureView( + name=f"base_driver_stats{name_suffix}", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=source, + ) + + +def create_derived_feature_view( + base_fv: FeatureView, sink_source: DataSource, name_suffix: str = "" +) -> FeatureView: + """Create a derived feature view that uses another feature view as source. + Note: This creates a regular FeatureView with another FeatureView as source. 
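+    The derived view must also declare a `sink_source` for the compute engine
+    to write its output to; a FeatureView that uses another FeatureView as its
+    source without a `sink_source` raises a ValueError (exercised in
+    test_ray_compute_engine_error_handling below).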
+ """ + return FeatureView( + name=f"derived_driver_stats{name_suffix}", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), # Same feature names as source + Field(name="acc_rate", dtype=Float32), # Same feature names as source + Field(name="avg_daily_trips", dtype=Int64), # Same feature names as source + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=base_fv, + sink_source=sink_source, + ) + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_single_source_feature_view(ray_environment, temp_dir): + """Test Ray compute engine with a single source feature view.""" + fs = ray_environment.feature_store + + data_source = create_feature_dataset(ray_environment) + base_fv = create_base_feature_view(data_source, "_single") + sink_source = create_unique_sink_source(temp_dir, "derived_sink_single") + derived_fv = create_derived_feature_view(base_fv, sink_source, "_single") + fs.apply([driver, base_fv, derived_fv]) + + entity_df = create_entity_df() + job = fs.get_historical_features( + entity_df=entity_df, + features=[ + f"{base_fv.name}:conv_rate", + f"{base_fv.name}:acc_rate", + f"{derived_fv.name}:conv_rate", + f"{derived_fv.name}:acc_rate", + ], + full_feature_names=True, + ) + result_df = job.to_df() + assert len(result_df) == 2 + assert f"{base_fv.name}__conv_rate" in result_df.columns + assert f"{base_fv.name}__acc_rate" in result_df.columns + assert f"{derived_fv.name}__conv_rate" in result_df.columns + assert f"{derived_fv.name}__acc_rate" in result_df.columns + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_materialization_with_source_feature_views( + ray_environment, temp_dir +): + """Test Ray compute engine materialization with source feature views.""" + fs = ray_environment.feature_store + data_source = create_feature_dataset(ray_environment) + base_fv = create_base_feature_view(data_source, "_materialize") + + sink_source = create_unique_sink_source(temp_dir, "derived_sink") + derived_fv = create_derived_feature_view(base_fv, sink_source, "_materialize") + + fs.apply([driver, base_fv, derived_fv]) + start_date = today - timedelta(days=7) + end_date = today + + # Materialize only the derived feature view - compute engine handles base dependencies + derived_job = fs.materialize( + start_date=start_date, + end_date=end_date, + feature_views=[derived_fv.name], + ) + + if derived_job is not None: + assert derived_job.status == MaterializationJobStatus.SUCCEEDED + else: + print("Materialization completed synchronously (no job object returned)") + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_cycle_detection(ray_environment, temp_dir): + """Test Ray compute engine cycle detection in feature view dependencies.""" + fs = ray_environment.feature_store + data_source = create_feature_dataset(ray_environment) + sink_source1 = create_unique_sink_source(temp_dir, "cycle_sink1") + sink_source2 = create_unique_sink_source(temp_dir, "cycle_sink2") + + fv1 = FeatureView( + name="cycle_fv1", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=data_source, + online=True, + offline=True, + ) + + fv2 = FeatureView( + name="cycle_fv2", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=fv1, + sink_source=sink_source1, + online=True, + offline=True, + ) + + fv3 = 
FeatureView( + name="cycle_fv3", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=fv2, + sink_source=sink_source2, + online=True, + offline=True, + ) + + # Apply feature views (this should work without cycles) + fs.apply([driver, fv1, fv2, fv3]) + + entity_df = create_entity_df() + + job = fs.get_historical_features( + entity_df=entity_df, + features=[ + f"{fv1.name}:conv_rate", + f"{fv2.name}:conv_rate", + f"{fv3.name}:conv_rate", + ], + full_feature_names=True, + ) + + result_df = job.to_df() + + assert len(result_df) == 2 + assert f"{fv1.name}__conv_rate" in result_df.columns + assert f"{fv2.name}__conv_rate" in result_df.columns + assert f"{fv3.name}__conv_rate" in result_df.columns + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_error_handling(ray_environment, temp_dir): + """Test Ray compute engine error handling with invalid source feature views.""" + fs = ray_environment.feature_store + data_source = create_feature_dataset(ray_environment) + base_fv = create_base_feature_view(data_source, "_error") + + # Test 1: Regular FeatureView with FeatureView source but no sink_source should fail + with pytest.raises( + ValueError, match="Derived FeatureView must specify `sink_source`" + ): + FeatureView( + name="invalid_fv", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=base_fv, + online=True, + offline=True, + ) + + # Test 2: Valid FeatureView with sink_source should work + sink_source = create_unique_sink_source(temp_dir, "valid_sink") + valid_fv = FeatureView( + name="valid_fv", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], # Use same feature name as source + source=base_fv, + sink_source=sink_source, + online=True, + offline=True, + ) + + fs.apply([driver, base_fv, valid_fv]) + entity_df = create_entity_df() + job = fs.get_historical_features( + entity_df=entity_df, + features=[ + f"{base_fv.name}:conv_rate", + f"{valid_fv.name}:conv_rate", # Use same feature name as source + ], + full_feature_names=True, + ) + + result_df = job.to_df() + assert len(result_df) == 2 + assert f"{base_fv.name}__conv_rate" in result_df.columns + assert f"{valid_fv.name}__conv_rate" in result_df.columns + assert result_df[f"{base_fv.name}__conv_rate"].notna().all() + assert result_df[f"{valid_fv.name}__conv_rate"].notna().all() + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_performance_with_source_feature_views( + ray_environment, temp_dir +): + """Test Ray compute engine performance with source feature views.""" + fs = ray_environment.feature_store + large_df = pd.DataFrame() + for i in range(1000): + large_df = pd.concat( + [ + large_df, + pd.DataFrame( + { + "driver_id": [1000 + i], + "event_timestamp": [today - timedelta(days=i % 30)], + "created": [now - timedelta(hours=i % 24)], + "conv_rate": [0.5 + (i % 10) * 0.05], + "acc_rate": [0.6 + (i % 10) * 0.04], + "avg_daily_trips": [10 + i % 20], + } + ), + ] + ) + data_source = ray_environment.data_source_creator.create_data_source( + large_df, + fs.project, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + base_fv = create_base_feature_view(data_source, "_perf") + sink_source1 = create_unique_sink_source(temp_dir, "perf_sink1") + derived_fv1 = create_derived_feature_view(base_fv, sink_source1, 
"_perf1") + sink_source2 = create_unique_sink_source(temp_dir, "perf_sink2") + derived_fv2 = create_derived_feature_view(derived_fv1, sink_source2, "_perf2") + fs.apply([driver, base_fv, derived_fv1, derived_fv2]) + + large_entity_df = pd.DataFrame( + { + "driver_id": [1000 + i for i in range(100)], + "event_timestamp": [today] * 100, + } + ) + start_time = time.time() + job = fs.get_historical_features( + entity_df=large_entity_df, + features=[ + f"{base_fv.name}:conv_rate", + f"{derived_fv1.name}:conv_rate", + ], + full_feature_names=True, + ) + result_df = job.to_df() + end_time = time.time() + assert len(result_df) == 100 + assert f"{base_fv.name}__conv_rate" in result_df.columns + assert f"{derived_fv1.name}__conv_rate" in result_df.columns + assert end_time - start_time < 60 diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 24e611c4f33..89a13df69ed 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -27,6 +27,9 @@ FeatureLoggingConfig, ) from feast.infra.feature_servers.local_process.config import LocalFeatureServerConfig +from feast.infra.offline_stores.contrib.ray_repo_configuration import ( + RayDataSourceCreator, +) from feast.permissions.action import AuthzedAction from feast.permissions.auth_model import OidcClientAuthConfig from feast.permissions.permission import Permission @@ -137,6 +140,7 @@ ("local", RemoteOfflineStoreDataSourceCreator), ("local", RemoteOfflineOidcAuthStoreDataSourceCreator), ("local", RemoteOfflineTlsStoreDataSourceCreator), + ("local", RayDataSourceCreator), ] if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": diff --git a/sdk/python/tests/unit/__init__.py b/sdk/python/tests/unit/__init__.py new file mode 100644 index 00000000000..ea3f8b923c2 --- /dev/null +++ b/sdk/python/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests package.""" diff --git a/sdk/python/tests/unit/infra/compute_engines/__init__.py b/sdk/python/tests/unit/infra/compute_engines/__init__.py new file mode 100644 index 00000000000..b1587145566 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/__init__.py @@ -0,0 +1 @@ +"""Compute engines unit tests package.""" diff --git a/sdk/python/tests/unit/infra/compute_engines/ray_compute/__init__.py b/sdk/python/tests/unit/infra/compute_engines/ray_compute/__init__.py new file mode 100644 index 00000000000..2734c36c704 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/ray_compute/__init__.py @@ -0,0 +1 @@ +"""Ray compute engine unit tests.""" diff --git a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py new file mode 100644 index 00000000000..e8c40d43099 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py @@ -0,0 +1,319 @@ +from datetime import datetime, timedelta + +import pandas as pd +import pytest +import ray + +from feast.aggregation import Aggregation +from feast.infra.compute_engines.dag.context import ColumnInfo +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.ray.nodes import ( + RayAggregationNode, + RayDedupNode, + RayFilterNode, + 
RayJoinNode, + RayReadNode, + RayTransformationNode, +) + + +class DummyInputNode(DAGNode): + def __init__(self, name, output): + super().__init__(name) + self._output = output + + def execute(self, context): + return self._output + + +class DummyFeatureView: + name = "dummy" + online = False + offline = False + + +class DummySource: + pass + + +class DummyRetrievalJob: + def __init__(self, ray_dataset): + self._ray_dataset = ray_dataset + + def to_ray_dataset(self): + return self._ray_dataset + + +@pytest.fixture(scope="session") +def ray_session(): + """Initialize Ray session for testing.""" + if not ray.is_initialized(): + ray.init(num_cpus=2, ignore_reinit_error=True, include_dashboard=False) + yield ray + ray.shutdown() + + +@pytest.fixture +def ray_config(): + """Create Ray compute engine configuration for testing.""" + return RayComputeEngineConfig( + type="ray.engine", + max_workers=2, + enable_optimization=True, + broadcast_join_threshold_mb=50, + target_partition_size_mb=32, + ) + + +@pytest.fixture +def mock_context(): + class DummyOfflineStore: + def offline_write_batch(self, *args, **kwargs): + pass + + class DummyContext: + def __init__(self): + self.registry = None + self.store = None + self.project = "test_project" + self.entity_data = None + self.config = None + self.node_outputs = {} + self.offline_store = DummyOfflineStore() + + return DummyContext() + + +@pytest.fixture +def sample_data(): + """Create sample data for testing.""" + return pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": datetime.now() - timedelta(hours=1), + "created": datetime.now() - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.5, + "avg_daily_trips": 15, + }, + { + "driver_id": 1002, + "event_timestamp": datetime.now() - timedelta(hours=2), + "created": datetime.now() - timedelta(hours=3), + "conv_rate": 0.7, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1001, + "event_timestamp": datetime.now() - timedelta(hours=3), + "created": datetime.now() - timedelta(hours=4), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + ] + ) + + +@pytest.fixture +def column_info(): + """Create a sample ColumnInfo for testing Ray nodes.""" + return ColumnInfo( + join_keys=["driver_id"], + feature_cols=["conv_rate", "acc_rate", "avg_daily_trips"], + ts_col="event_timestamp", + created_ts_col="created", + field_mapping=None, + ) + + +def test_ray_read_node(ray_session, ray_config, mock_context, sample_data, column_info): + """Test RayReadNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + mock_source = DummySource() + node = RayReadNode( + name="read", + source=mock_source, + column_info=column_info, + config=ray_config, + ) + mock_context.registry = None + mock_context.store = None + mock_context.offline_store = None + mock_retrieval_job = DummyRetrievalJob(ray_dataset) + import feast.infra.compute_engines.ray.nodes as ray_nodes + + ray_nodes.create_offline_store_retrieval_job = lambda **kwargs: mock_retrieval_job + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 3 + assert "driver_id" in result_df.columns + assert "conv_rate" in result_df.columns + + +def test_ray_aggregation_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayAggregationNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + input_value = DAGValue(data=ray_dataset, 
format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayAggregationNode( + name="aggregation", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + group_by_keys=["driver_id"], + timestamp_col="event_timestamp", + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 2 + assert "driver_id" in result_df.columns + assert "sum_conv_rate" in result_df.columns + assert "avg_acc_rate" in result_df.columns + + +def test_ray_join_node(ray_session, ray_config, mock_context, sample_data, column_info): + """Test RayJoinNode functionality.""" + entity_data = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": datetime.now()}, + {"driver_id": 1002, "event_timestamp": datetime.now()}, + ] + ) + feature_dataset = ray.data.from_pandas(sample_data) + feature_value = DAGValue(data=feature_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("feature_node", feature_value) + node = RayJoinNode( + name="join", + column_info=column_info, + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"feature_node": feature_value} + mock_context.entity_df = entity_data + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) >= 2 + assert "driver_id" in result_df.columns + + +def test_ray_transformation_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayTransformationNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + + def transform_feature(df: pd.DataFrame) -> pd.DataFrame: + df["conv_rate_doubled"] = df["conv_rate"] * 2 + return df + + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayTransformationNode( + name="transformation", + transformation=transform_feature, + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 3 + assert "conv_rate_doubled" in result_df.columns + assert ( + result_df["conv_rate_doubled"].iloc[0] == sample_data["conv_rate"].iloc[0] * 2 + ) + + +def test_ray_filter_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayFilterNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayFilterNode( + name="filter", + column_info=column_info, + config=ray_config, + ttl=timedelta(hours=2), + filter_condition=None, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) <= 3 + assert "event_timestamp" in result_df.columns + + +def test_ray_dedup_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test 
RayDedupNode functionality.""" + duplicated_data = pd.concat([sample_data, sample_data.iloc[:1]], ignore_index=True) + ray_dataset = ray.data.from_pandas(duplicated_data) + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayDedupNode( + name="dedup", + column_info=column_info, + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 2 # Should remove the duplicate row + assert "driver_id" in result_df.columns + + +def test_ray_config_validation(): + """Test Ray configuration validation.""" + # Test valid configuration + config = RayComputeEngineConfig( + type="ray.engine", + max_workers=4, + enable_optimization=True, + broadcast_join_threshold_mb=100, + target_partition_size_mb=64, + window_size_for_joins="30min", + ) + + assert config.type == "ray.engine" + assert config.max_workers == 4 + assert config.window_size_timedelta == timedelta(minutes=30) + + # Test window size parsing + config_hours = RayComputeEngineConfig(window_size_for_joins="2H") + assert config_hours.window_size_timedelta == timedelta(hours=2) + + config_seconds = RayComputeEngineConfig(window_size_for_joins="30s") + assert config_seconds.window_size_timedelta == timedelta(seconds=30) + + # Test invalid window size defaults to 1 hour + config_invalid = RayComputeEngineConfig(window_size_for_joins="invalid") + assert config_invalid.window_size_timedelta == timedelta(hours=1) diff --git a/setup.py b/setup.py index 033b2491c02..7545b0c19ae 100644 --- a/setup.py +++ b/setup.py @@ -180,6 +180,8 @@ "datasets>=3.6.0", ] +RAY_REQUIRED = ["ray>=2.47.0"] + CI_REQUIRED = ( [ "build", @@ -256,6 +258,7 @@ + CLICKHOUSE_REQUIRED + MCP_REQUIRED + RAG_REQUIRED + + RAY_REQUIRED ) MINIMAL_REQUIRED = ( GCP_REQUIRED @@ -358,6 +361,7 @@ "clickhouse": CLICKHOUSE_REQUIRED, "mcp": MCP_REQUIRED, "rag": RAG_REQUIRED, + "ray": RAY_REQUIRED, }, include_package_data=True, license="Apache",
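As a closing illustration, the sketch below shows how the `RayComputeEngineConfig` surface pinned down by `test_ray_compute_engine_config` and `test_ray_config_validation` could be exercised directly from a script. It is a minimal sketch that relies only on the fields and `window_size_timedelta` behaviour asserted in those tests; the specific values chosen are illustrative.

```python
from datetime import timedelta

from feast.infra.compute_engines.ray.config import RayComputeEngineConfig

# Minimal sketch: construct the engine config programmatically using the same
# fields the tests above assert on; the values here are illustrative only.
config = RayComputeEngineConfig(
    type="ray.engine",
    max_workers=2,
    enable_optimization=True,
    broadcast_join_threshold_mb=100,
    target_partition_size_mb=64,
    window_size_for_joins="30min",
)

# The window string is exposed as a timedelta; "30min" and the fallback for
# unparseable values are both covered by test_ray_config_validation.
assert config.window_size_timedelta == timedelta(minutes=30)
assert RayComputeEngineConfig(
    window_size_for_joins="invalid"
).window_size_timedelta == timedelta(hours=1)
```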