From e725c41f90c27721395fb677860a4fae3d462115 Mon Sep 17 00:00:00 2001
From: ntkathole
Date: Sat, 13 Sep 2025 09:55:43 +0530
Subject: [PATCH] fix: Fix hostname resolution for spark tests

Signed-off-by: ntkathole
---
 .../infra/compute_engines/spark/utils.py       |  3 ++
 .../spark_offline_store/tests/data_source.py   |  2 ++
 .../infra/compute_engines/spark/test_nodes.py  |  2 ++
 .../test_spark_transformation.py               | 28 ++++++++++++++++---
 4 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py
index 4e429f8e075..a03cfdb12d1 100644
--- a/sdk/python/feast/infra/compute_engines/spark/utils.py
+++ b/sdk/python/feast/infra/compute_engines/spark/utils.py
@@ -21,6 +21,9 @@ def get_or_create_new_spark_session(
         conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])
     )
 
+    spark_builder = spark_builder.config("spark.driver.host", "127.0.0.1")
+    spark_builder = spark_builder.config("spark.driver.bindAddress", "127.0.0.1")
+
     spark_session = spark_builder.getOrCreate()
     spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
     return spark_session
diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
index 7093e40b99e..b723037f1f3 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
@@ -35,6 +35,8 @@ def __init__(self, project_name: str, *args, **kwargs):
             "spark.eventLog.enabled": "false",
             "spark.sql.parser.quotedRegexColumnNames": "true",
             "spark.sql.session.timeZone": "UTC",
+            "spark.driver.host": "127.0.0.1",
+            "spark.driver.bindAddress": "127.0.0.1",
         }
         if not self.spark_offline_store_config:
             self.create_offline_store_config()
diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
index 61824074ae1..0fb2bd4cb78 100644
--- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
+++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
@@ -25,6 +25,8 @@ def spark_session():
         SparkSession.builder.appName("FeastSparkTests")
         .master("local[*]")
         .config("spark.sql.shuffle.partitions", "1")
+        .config("spark.driver.host", "127.0.0.1")
+        .config("spark.driver.bindAddress", "127.0.0.1")
         .getOrCreate()
     )
 
diff --git a/sdk/python/tests/unit/transformation/test_spark_transformation.py b/sdk/python/tests/unit/transformation/test_spark_transformation.py
index 63954faef2f..63d9b520ce9 100644
--- a/sdk/python/tests/unit/transformation/test_spark_transformation.py
+++ b/sdk/python/tests/unit/transformation/test_spark_transformation.py
@@ -52,7 +52,12 @@ def remove_extra_spaces_sql(df, column_name):
 
 @pytest.fixture
 def spark_fixture():
-    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    spark = (
+        SparkSession.builder.appName("Testing PySpark Example")
+        .config("spark.driver.host", "127.0.0.1")
+        .config("spark.driver.bindAddress", "127.0.0.1")
+        .getOrCreate()
+    )
     try:
         yield spark
     finally:
@@ -61,7 +66,12 @@ def spark_fixture():
 
 @patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
 def test_spark_transformation(spark_fixture):
-    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    spark = (
+        SparkSession.builder.appName("Testing PySpark Example")
+        .config("spark.driver.host", "127.0.0.1")
+        .config("spark.driver.bindAddress", "127.0.0.1")
+        .getOrCreate()
+    )
 
     df = get_sample_df(spark)
     spark_transformation = Transformation(
@@ -77,7 +87,12 @@ def test_spark_transformation(spark_fixture):
 
 @patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
 def test_spark_transformation_init_transformation(spark_fixture):
-    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    spark = (
+        SparkSession.builder.appName("Testing PySpark Example")
+        .config("spark.driver.host", "127.0.0.1")
+        .config("spark.driver.bindAddress", "127.0.0.1")
+        .getOrCreate()
+    )
 
     df = get_sample_df(spark)
     spark_transformation = SparkTransformation(
@@ -93,7 +108,12 @@ def test_spark_transformation_init_transformation(spark_fixture):
 
 @patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
 def test_spark_transformation_sql(spark_fixture):
-    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    spark = (
+        SparkSession.builder.appName("Testing PySpark Example")
+        .config("spark.driver.host", "127.0.0.1")
+        .config("spark.driver.bindAddress", "127.0.0.1")
+        .getOrCreate()
+    )
 
     df = get_sample_df(spark)
     spark_transformation = SparkTransformation(
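
For reference, a minimal standalone sketch of the technique this patch applies (not part of the patch; assumes pyspark is installed locally, and the app name and assertions are illustrative). Pinning spark.driver.host and spark.driver.bindAddress to the loopback address keeps SparkSession startup from depending on the machine's hostname being resolvable, which appears to be the failure mode the commit subject refers to:

    from pyspark.sql import SparkSession

    # Build a local session the same way the patched tests do.
    spark = (
        SparkSession.builder.appName("LoopbackSmokeTest")  # illustrative name
        .master("local[*]")
        .config("spark.driver.host", "127.0.0.1")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .getOrCreate()
    )

    # Both settings were supplied at build time, so they read back from the
    # session's runtime conf once the session is up.
    assert spark.conf.get("spark.driver.host") == "127.0.0.1"
    assert spark.conf.get("spark.driver.bindAddress") == "127.0.0.1"

    # Run a trivial job to confirm the driver is reachable over loopback.
    assert spark.range(10).count() == 10
    spark.stop()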