diff --git a/sdk/python/feast/api/registry/rest/data_sources.py b/sdk/python/feast/api/registry/rest/data_sources.py index 4927999a951..33200d2a534 100644 --- a/sdk/python/feast/api/registry/rest/data_sources.py +++ b/sdk/python/feast/api/registry/rest/data_sources.py @@ -44,7 +44,7 @@ def list_data_sources( data_sources = response.get("dataSources", []) result = { - "data_sources": data_sources, + "dataSources": data_sources, "pagination": response.get("pagination", {}), } diff --git a/sdk/python/feast/api/registry/rest/entities.py b/sdk/python/feast/api/registry/rest/entities.py index 318fdb5822b..3d39dfbf0f3 100644 --- a/sdk/python/feast/api/registry/rest/entities.py +++ b/sdk/python/feast/api/registry/rest/entities.py @@ -95,10 +95,28 @@ def get_entity( result = entity + relationships = get_object_relationships( + grpc_handler, "entity", name, project, allow_cache + ) + ds_list_req = RegistryServer_pb2.ListDataSourcesRequest( + project=project, + allow_cache=allow_cache, + ) + ds_list_resp = grpc_call(grpc_handler.ListDataSources, ds_list_req) + ds_map = {ds["name"]: ds for ds in ds_list_resp.get("dataSources", [])} + data_source_objs = [] + seen_ds_names = set() + for rel in relationships: + if rel.get("target", {}).get("type") == "dataSource": + ds_name = rel["target"]["name"] + if ds_name not in seen_ds_names: + ds_obj = ds_map.get(ds_name) + if ds_obj: + data_source_objs.append(ds_obj) + seen_ds_names.add(ds_name) + result["dataSources"] = data_source_objs + if include_relationships: - relationships = get_object_relationships( - grpc_handler, "entity", name, project, allow_cache - ) result["relationships"] = relationships return result diff --git a/sdk/python/feast/api/registry/rest/features.py b/sdk/python/feast/api/registry/rest/features.py index 74a9c4bd5f3..7e6296a9858 100644 --- a/sdk/python/feast/api/registry/rest/features.py +++ b/sdk/python/feast/api/registry/rest/features.py @@ -39,9 +39,15 @@ def list_features( sorting=create_grpc_sorting_params(sorting_params), ) response = grpc_call(grpc_handler.ListFeatures, req) + if "features" not in response: + response["features"] = [] + if "pagination" not in response: + response["pagination"] = {} + if include_relationships: + features = response.get("features", []) relationships = get_relationships_for_objects( - grpc_handler, response["features"], "feature", project, allow_cache + grpc_handler, features, "feature", project, allow_cache ) response["relationships"] = relationships return response diff --git a/sdk/python/tests/unit/api/test_api_rest_registry.py b/sdk/python/tests/unit/api/test_api_rest_registry.py index 2748994ba4e..bc893861938 100644 --- a/sdk/python/tests/unit/api/test_api_rest_registry.py +++ b/sdk/python/tests/unit/api/test_api_rest_registry.py @@ -145,8 +145,7 @@ def test_feature_services_via_rest(fastapi_test_app): def test_data_sources_via_rest(fastapi_test_app): response = fastapi_test_app.get("/data_sources?project=demo_project") - assert response.status_code == 200 - assert "data_sources" in response.json() + assert "dataSources" in response.json() response = fastapi_test_app.get( "/data_sources/user_profile_source?project=demo_project" ) @@ -650,9 +649,9 @@ def test_data_sources_pagination_via_rest(fastapi_test_app_with_multiple_objects response = client.get("/data_sources?project=demo_project&page=1&limit=2") assert response.status_code == 200 data = response.json() - assert "data_sources" in data + assert "dataSources" in data assert "pagination" in data - assert len(data["data_sources"]) == 2 + assert len(data["dataSources"]) == 2 assert data["pagination"]["page"] == 1 assert data["pagination"]["limit"] == 2 assert data["pagination"]["totalCount"] == 3 @@ -669,7 +668,7 @@ def test_data_sources_sorting_via_rest(fastapi_test_app_with_multiple_objects): ) assert response.status_code == 200 data = response.json() - ds_names = [ds["name"] for ds in data["data_sources"]] + ds_names = [ds["name"] for ds in data["dataSources"]] assert ds_names == sorted(ds_names) @@ -1064,3 +1063,70 @@ def test_lineage_complete_all_via_rest(fastapi_test_app): assert "dataSources" in project_data["objects"] assert "featureViews" in project_data["objects"] assert "featureServices" in project_data["objects"] + + +def test_invalid_project_name_with_relationships_via_rest(fastapi_test_app): + """Test REST API response with invalid project name using include_relationships=true. + The API should not throw 500 or any other error when an invalid project name is provided + with include_relationships=true parameter. + """ + response = fastapi_test_app.get( + "/entities?project=invalid_project_name&include_relationships=true" + ) + assert response.status_code == 200 + data = response.json() + assert "entities" in data + assert isinstance(data["entities"], list) + assert len(data["entities"]) == 0 + assert "relationships" in data + assert isinstance(data["relationships"], dict) + assert len(data["relationships"]) == 0 + + response = fastapi_test_app.get( + "/feature_views?project=invalid_project_name&include_relationships=true" + ) + assert response.status_code == 200 + data = response.json() + assert "featureViews" in data + assert isinstance(data["featureViews"], list) + assert len(data["featureViews"]) == 0 + assert "relationships" in data + assert isinstance(data["relationships"], dict) + assert len(data["relationships"]) == 0 + + response = fastapi_test_app.get( + "/data_sources?project=invalid_project_name&include_relationships=true" + ) + # Should return 200 with empty results, not 500 or other errors + assert response.status_code == 200 + data = response.json() + assert "dataSources" in data + assert isinstance(data["dataSources"], list) + assert len(data["dataSources"]) == 0 + assert "relationships" in data + assert isinstance(data["relationships"], dict) + assert len(data["relationships"]) == 0 + + response = fastapi_test_app.get( + "/feature_services?project=invalid_project_name&include_relationships=true" + ) + assert response.status_code == 200 + data = response.json() + assert "featureServices" in data + assert isinstance(data["featureServices"], list) + assert len(data["featureServices"]) == 0 + assert "relationships" in data + assert isinstance(data["relationships"], dict) + assert len(data["relationships"]) == 0 + + response = fastapi_test_app.get( + "/features?project=invalid_project_name&include_relationships=true" + ) + assert response.status_code == 200 + data = response.json() + assert "features" in data + assert isinstance(data["features"], list) + assert len(data["features"]) == 0 + assert "relationships" in data + assert isinstance(data["relationships"], dict) + assert len(data["relationships"]) == 0