From f1e78e9193c0e5a69fa298f8cb4297be1518d490 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Sat, 14 Jun 2025 21:16:41 +0530 Subject: [PATCH 01/10] feat: Added ray to requirements Signed-off-by: ntkathole --- pyproject.toml | 5 +- .../requirements/py3.10-ci-requirements.txt | 100 +++++++++++++++++- .../requirements/py3.11-ci-requirements.txt | 100 +++++++++++++++++- .../requirements/py3.12-ci-requirements.txt | 100 +++++++++++++++++- setup.py | 5 + 5 files changed, 299 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1ee36c3a102..2879bfbba26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ rag = [ "transformers>=4.36.0", "datasets>=3.6.0", ] +ray = ["ray>=2.47.0"] redis = [ "redis>=4.2.2,<5", "hiredis>=2.0.0,<3", @@ -167,9 +168,9 @@ ci = [ "types-setuptools", "types-tabulate", "virtualenv<20.24.2", - "feast[aws, azure, cassandra, clickhouse, couchbase, delta, docling, duckdb, elasticsearch, faiss, gcp, ge, go, grpcio, hazelcast, hbase, ibis, ikv, k8s, mcp, milvus, mssql, mysql, opentelemetry, spark, trino, postgres, pytorch, qdrant, rag, redis, singlestore, snowflake, sqlite_vec]" + "feast[aws, azure, cassandra, clickhouse, couchbase, delta, docling, duckdb, elasticsearch, faiss, gcp, ge, go, grpcio, hazelcast, hbase, ibis, ikv, k8s, mcp, milvus, mssql, mysql, opentelemetry, spark, trino, postgres, pytorch, qdrant, rag, ray, redis, singlestore, snowflake, sqlite_vec]" ] -nlp = ["feast[docling, milvus, pytorch, rag]"] +nlp = ["feast[docling, milvus, pytorch, rag, ray]"] dev = ["feast[ci]"] docs = ["feast[ci]"] # used for the 'feature-server' container image build diff --git a/sdk/python/requirements/py3.10-ci-requirements.txt b/sdk/python/requirements/py3.10-ci-requirements.txt index 970153b304a..499feed262d 100644 --- a/sdk/python/requirements/py3.10-ci-requirements.txt +++ b/sdk/python/requirements/py3.10-ci-requirements.txt @@ -502,6 +502,7 @@ click==8.2.1 \ # geomet # great-expectations # pip-tools + # ray # typer # uvicorn clickhouse-connect==0.8.18 \ @@ -1062,6 +1063,7 @@ filelock==3.18.0 \ # via # datasets # huggingface-hub + # ray # snowflake-connector-python # torch # transformers @@ -1652,9 +1654,9 @@ httpx-sse==0.4.1 \ --hash=sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e \ --hash=sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37 # via mcp -huggingface-hub==0.34.3 \ - --hash=sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492 \ - --hash=sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853 +huggingface-hub==0.34.4 \ + --hash=sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a \ + --hash=sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c # via # accelerate # datasets @@ -1795,6 +1797,7 @@ jsonschema[format-nongpl]==4.25.0 \ # jupyterlab-server # mcp # nbformat + # ray jsonschema-specifications==2025.4.1 \ --hash=sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af \ --hash=sha256:630159c9f4dbea161a6a2205c3011cc4f18ff381b189fff48bb39b9bf26ae608 @@ -2327,6 +2330,67 @@ msal-extensions==1.3.1 \ --hash=sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca \ --hash=sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4 # via azure-identity +msgpack==1.1.1 \ + --hash=sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8 \ + --hash=sha256:1abfc6e949b352dadf4bce0eb78023212ec5ac42f6abfd469ce91d783c149c2a \ + --hash=sha256:1b13fe0fb4aac1aa5320cd693b297fe6fdef0e7bea5518cbc2dd5299f873ae90 \ + --hash=sha256:1d75f3807a9900a7d575d8d6674a3a47e9f227e8716256f35bc6f03fc597ffbf \ + --hash=sha256:2fbbc0b906a24038c9958a1ba7ae0918ad35b06cb449d398b76a7d08470b0ed9 \ + --hash=sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157 \ + --hash=sha256:353b6fc0c36fde68b661a12949d7d49f8f51ff5fa019c1e47c87c4ff34b080ed \ + --hash=sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d \ + --hash=sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0 \ + --hash=sha256:3a89cd8c087ea67e64844287ea52888239cbd2940884eafd2dcd25754fb72232 \ + --hash=sha256:40eae974c873b2992fd36424a5d9407f93e97656d999f43fca9d29f820899084 \ + --hash=sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5 \ + --hash=sha256:435807eeb1bc791ceb3247d13c79868deb22184e1fc4224808750f0d7d1affc1 \ + --hash=sha256:4835d17af722609a45e16037bb1d4d78b7bdf19d6c0128116d178956618c4e88 \ + --hash=sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752 \ + --hash=sha256:4d3237b224b930d58e9d83c81c0dba7aacc20fcc2f89c1e5423aa0529a4cd142 \ + --hash=sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac \ + --hash=sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef \ + --hash=sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323 \ + --hash=sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4 \ + --hash=sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458 \ + --hash=sha256:61abccf9de335d9efd149e2fff97ed5974f2481b3353772e8e2dd3402ba2bd57 \ + --hash=sha256:61e35a55a546a1690d9d09effaa436c25ae6130573b6ee9829c37ef0f18d5e78 \ + --hash=sha256:6640fd979ca9a212e4bcdf6eb74051ade2c690b862b679bfcb60ae46e6dc4bfd \ + --hash=sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69 \ + --hash=sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce \ + --hash=sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558 \ + --hash=sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd \ + --hash=sha256:78426096939c2c7482bf31ef15ca219a9e24460289c00dd0b94411040bb73ad2 \ + --hash=sha256:79c408fcf76a958491b4e3b103d1c417044544b68e96d06432a189b43d1215c8 \ + --hash=sha256:7a17ac1ea6ec3c7687d70201cfda3b1e8061466f28f686c24f627cae4ea8efd0 \ + --hash=sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295 \ + --hash=sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c \ + --hash=sha256:88d1e966c9235c1d4e2afac21ca83933ba59537e2e2727a999bf3f515ca2af26 \ + --hash=sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2 \ + --hash=sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f \ + --hash=sha256:8b17ba27727a36cb73aabacaa44b13090feb88a01d012c0f4be70c00f75048b4 \ + --hash=sha256:8b65b53204fe1bd037c40c4148d00ef918eb2108d24c9aaa20bc31f9810ce0a8 \ + --hash=sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9 \ + --hash=sha256:96decdfc4adcbc087f5ea7ebdcfd3dee9a13358cae6e81d54be962efc38f6338 \ + --hash=sha256:996f2609ddf0142daba4cefd767d6db26958aac8439ee41db9cc0db9f4c4c3a6 \ + --hash=sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a \ + --hash=sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0 \ + --hash=sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a \ + --hash=sha256:a8ef6e342c137888ebbfb233e02b8fbd689bb5b5fcc59b34711ac47ebd504478 \ + --hash=sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238 \ + --hash=sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7 \ + --hash=sha256:b8f93dcddb243159c9e4109c9750ba5b335ab8d48d9522c5308cd05d7e3ce600 \ + --hash=sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704 \ + --hash=sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a \ + --hash=sha256:bba1be28247e68994355e028dcd668316db30c1f758d3241a7b903ac78dcd285 \ + --hash=sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c \ + --hash=sha256:d182dac0221eb8faef2e6f44701812b467c02674a322c739355c39e94730cdbf \ + --hash=sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b \ + --hash=sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2 \ + --hash=sha256:da8f41e602574ece93dbbda1fab24650d6bf2a24089f9e9dbb4f5730ec1e58ad \ + --hash=sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b \ + --hash=sha256:f5be6b6bc52fad84d010cb45433720327ce886009d862f46b26d4d154001994b \ + --hash=sha256:f6d58656842e1b2ddbe07f43f56b10a60f2ba5826164910968f5933e5178af75 + # via ray multidict==6.6.3 \ --hash=sha256:02fd8f32d403a6ff13864b0851f1f523d4c988051eea0471d4f1fd8010f11134 \ --hash=sha256:04cbcce84f63b9af41bad04a54d4cc4e60e90c35b9e6ccb130be2d75b71f8c17 \ @@ -2688,6 +2752,7 @@ packaging==24.2 \ # nbconvert # pandas-gbq # pytest + # ray # scikit-image # snowflake-connector-python # sphinx @@ -3090,6 +3155,7 @@ protobuf==4.25.8 \ # proto-plus # pymilvus # qdrant-client + # ray # substrait psutil==5.9.0 \ --hash=sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5 \ @@ -3962,6 +4028,7 @@ pyyaml==6.0.2 \ # jupyter-events # kubernetes # pre-commit + # ray # responses # transformers # uvicorn @@ -4066,6 +4133,32 @@ qdrant-client==1.15.1 \ --hash=sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63 \ --hash=sha256:631f1f3caebfad0fd0c1fba98f41be81d9962b7bf3ca653bed3b727c0e0cbe0e # via feast (setup.py) +ray==2.48.0 \ + --hash=sha256:24a70f416ec0be14b975f160044805ccb48cc6bc50de632983eb8f0a8e16682b \ + --hash=sha256:25e4b79fcc8f849d72db1acc4f03f37008c5c0b745df63d8a30cd35676b6545e \ + --hash=sha256:33bda4753ad0acd2b524c9158089d43486cd44cc59fe970466435bc2968fde2d \ + --hash=sha256:46d4b42a58492dec79caad2d562344689a4f99a828aeea811a0cd2cd653553ef \ + --hash=sha256:4b9b92ac29635f555ef341347d9a63dbf02b7d946347239af3c09e364bc45cf8 \ + --hash=sha256:5742b72a514afe5d60f41330200cd508376e16c650f6962e62337aa482d6a0c6 \ + --hash=sha256:5a6f57126eac9dd3286289e07e91e87b054792f9698b6f7ccab88b624816b542 \ + --hash=sha256:622e6bcdb78d98040d87bea94e65d0bb6ccc0ae1b43294c6bd69f542bf28e092 \ + --hash=sha256:649ed9442dc2d39135c593b6cf0c38e8355170b92672365ab7a3cbc958c42634 \ + --hash=sha256:6ca2b9ce45ad360cbe2996982fb22691ecfe6553ec8f97a2548295f0f96aac78 \ + --hash=sha256:8de799f3b0896f48d306d5e4a04fc6037a08c495d45f9c79935344e5693e3cf8 \ + --hash=sha256:a42ed3b640f4b599a3fc8067c83ee60497c0f03d070d7a7df02a388fa17a546b \ + --hash=sha256:a45de103173c2ed6a0defd7a2919a2bbe531fd5bf6619860cd111ca4a16e9288 \ + --hash=sha256:a7a6d830d9dc5ae8bb156fcde9a1adab7f4edb004f03918a724d885eceb8264d \ + --hash=sha256:b37a0fea4094f95d5926b1d7245abd70deb62882da3d738f9f9b76214894745c \ + --hash=sha256:b427dead5f8ad96d494d3a006d92ea2f8f16be5e6303b590e12234b37f96fbc2 \ + --hash=sha256:b94500fe2d17e491fe2e9bd4a3bf62df217e21a8f2845033c353d4d2ea240f73 \ + --hash=sha256:be45690565907c4aa035d753d82f6ff892d1e6830057b67399542a035b3682f0 \ + --hash=sha256:cfb48c10371c267fdcf7f4ae359cab706f068178b9c65317ead011972f2c0bf3 \ + --hash=sha256:e15fdffa6b60d5729f6025691396b8a01dc3461ba19dc92bba354ec1813ed6b1 \ + --hash=sha256:e6543fb3450a71862cfa1e7a666025d751f81685602cc87d499072ccd839507d \ + --hash=sha256:ea9d7739ae8f6db48b226bbc2a592640f7f2b6d854ff73d0305774b98fa9fb11 \ + --hash=sha256:f1cf33d260316f92f77558185f1c36fc35506d76ee7fdfed9f5b70f9c4bdba7f \ + --hash=sha256:f820950bc44d7b000c223342f5c800c9c08e7fd89524201125388ea211caad1a + # via feast (setup.py) redis==4.6.0 \ --hash=sha256:585dc516b9eb042a619ef0a39c3d7d55fe81bdb4df09a52c9cdde0d07bf1aa7d \ --hash=sha256:e2b03db868160ee4591de3cb90d40ebb50a90dd302138775937f6a42b7ed183c @@ -4189,6 +4282,7 @@ requests==2.32.4 \ # moto # msal # python-keycloak + # ray # requests-oauthlib # requests-toolbelt # responses diff --git a/sdk/python/requirements/py3.11-ci-requirements.txt b/sdk/python/requirements/py3.11-ci-requirements.txt index a44b6551fbf..a14ea9657a9 100644 --- a/sdk/python/requirements/py3.11-ci-requirements.txt +++ b/sdk/python/requirements/py3.11-ci-requirements.txt @@ -500,6 +500,7 @@ click==8.2.1 \ # geomet # great-expectations # pip-tools + # ray # typer # uvicorn clickhouse-connect==0.8.18 \ @@ -1053,6 +1054,7 @@ filelock==3.18.0 \ # via # datasets # huggingface-hub + # ray # snowflake-connector-python # torch # transformers @@ -1643,9 +1645,9 @@ httpx-sse==0.4.1 \ --hash=sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e \ --hash=sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37 # via mcp -huggingface-hub==0.34.3 \ - --hash=sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492 \ - --hash=sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853 +huggingface-hub==0.34.4 \ + --hash=sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a \ + --hash=sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c # via # accelerate # datasets @@ -1788,6 +1790,7 @@ jsonschema[format-nongpl]==4.25.0 \ # jupyterlab-server # mcp # nbformat + # ray jsonschema-specifications==2025.4.1 \ --hash=sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af \ --hash=sha256:630159c9f4dbea161a6a2205c3011cc4f18ff381b189fff48bb39b9bf26ae608 @@ -2320,6 +2323,67 @@ msal-extensions==1.3.1 \ --hash=sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca \ --hash=sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4 # via azure-identity +msgpack==1.1.1 \ + --hash=sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8 \ + --hash=sha256:1abfc6e949b352dadf4bce0eb78023212ec5ac42f6abfd469ce91d783c149c2a \ + --hash=sha256:1b13fe0fb4aac1aa5320cd693b297fe6fdef0e7bea5518cbc2dd5299f873ae90 \ + --hash=sha256:1d75f3807a9900a7d575d8d6674a3a47e9f227e8716256f35bc6f03fc597ffbf \ + --hash=sha256:2fbbc0b906a24038c9958a1ba7ae0918ad35b06cb449d398b76a7d08470b0ed9 \ + --hash=sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157 \ + --hash=sha256:353b6fc0c36fde68b661a12949d7d49f8f51ff5fa019c1e47c87c4ff34b080ed \ + --hash=sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d \ + --hash=sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0 \ + --hash=sha256:3a89cd8c087ea67e64844287ea52888239cbd2940884eafd2dcd25754fb72232 \ + --hash=sha256:40eae974c873b2992fd36424a5d9407f93e97656d999f43fca9d29f820899084 \ + --hash=sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5 \ + --hash=sha256:435807eeb1bc791ceb3247d13c79868deb22184e1fc4224808750f0d7d1affc1 \ + --hash=sha256:4835d17af722609a45e16037bb1d4d78b7bdf19d6c0128116d178956618c4e88 \ + --hash=sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752 \ + --hash=sha256:4d3237b224b930d58e9d83c81c0dba7aacc20fcc2f89c1e5423aa0529a4cd142 \ + --hash=sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac \ + --hash=sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef \ + --hash=sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323 \ + --hash=sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4 \ + --hash=sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458 \ + --hash=sha256:61abccf9de335d9efd149e2fff97ed5974f2481b3353772e8e2dd3402ba2bd57 \ + --hash=sha256:61e35a55a546a1690d9d09effaa436c25ae6130573b6ee9829c37ef0f18d5e78 \ + --hash=sha256:6640fd979ca9a212e4bcdf6eb74051ade2c690b862b679bfcb60ae46e6dc4bfd \ + --hash=sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69 \ + --hash=sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce \ + --hash=sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558 \ + --hash=sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd \ + --hash=sha256:78426096939c2c7482bf31ef15ca219a9e24460289c00dd0b94411040bb73ad2 \ + --hash=sha256:79c408fcf76a958491b4e3b103d1c417044544b68e96d06432a189b43d1215c8 \ + --hash=sha256:7a17ac1ea6ec3c7687d70201cfda3b1e8061466f28f686c24f627cae4ea8efd0 \ + --hash=sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295 \ + --hash=sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c \ + --hash=sha256:88d1e966c9235c1d4e2afac21ca83933ba59537e2e2727a999bf3f515ca2af26 \ + --hash=sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2 \ + --hash=sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f \ + --hash=sha256:8b17ba27727a36cb73aabacaa44b13090feb88a01d012c0f4be70c00f75048b4 \ + --hash=sha256:8b65b53204fe1bd037c40c4148d00ef918eb2108d24c9aaa20bc31f9810ce0a8 \ + --hash=sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9 \ + --hash=sha256:96decdfc4adcbc087f5ea7ebdcfd3dee9a13358cae6e81d54be962efc38f6338 \ + --hash=sha256:996f2609ddf0142daba4cefd767d6db26958aac8439ee41db9cc0db9f4c4c3a6 \ + --hash=sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a \ + --hash=sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0 \ + --hash=sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a \ + --hash=sha256:a8ef6e342c137888ebbfb233e02b8fbd689bb5b5fcc59b34711ac47ebd504478 \ + --hash=sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238 \ + --hash=sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7 \ + --hash=sha256:b8f93dcddb243159c9e4109c9750ba5b335ab8d48d9522c5308cd05d7e3ce600 \ + --hash=sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704 \ + --hash=sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a \ + --hash=sha256:bba1be28247e68994355e028dcd668316db30c1f758d3241a7b903ac78dcd285 \ + --hash=sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c \ + --hash=sha256:d182dac0221eb8faef2e6f44701812b467c02674a322c739355c39e94730cdbf \ + --hash=sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b \ + --hash=sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2 \ + --hash=sha256:da8f41e602574ece93dbbda1fab24650d6bf2a24089f9e9dbb4f5730ec1e58ad \ + --hash=sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b \ + --hash=sha256:f5be6b6bc52fad84d010cb45433720327ce886009d862f46b26d4d154001994b \ + --hash=sha256:f6d58656842e1b2ddbe07f43f56b10a60f2ba5826164910968f5933e5178af75 + # via ray multidict==6.6.3 \ --hash=sha256:02fd8f32d403a6ff13864b0851f1f523d4c988051eea0471d4f1fd8010f11134 \ --hash=sha256:04cbcce84f63b9af41bad04a54d4cc4e60e90c35b9e6ccb130be2d75b71f8c17 \ @@ -2700,6 +2764,7 @@ packaging==24.2 \ # nbconvert # pandas-gbq # pytest + # ray # scikit-image # snowflake-connector-python # sphinx @@ -3102,6 +3167,7 @@ protobuf==4.25.8 \ # proto-plus # pymilvus # qdrant-client + # ray # substrait psutil==5.9.0 \ --hash=sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5 \ @@ -3975,6 +4041,7 @@ pyyaml==6.0.2 \ # jupyter-events # kubernetes # pre-commit + # ray # responses # transformers # uvicorn @@ -4079,6 +4146,32 @@ qdrant-client==1.15.1 \ --hash=sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63 \ --hash=sha256:631f1f3caebfad0fd0c1fba98f41be81d9962b7bf3ca653bed3b727c0e0cbe0e # via feast (setup.py) +ray==2.48.0 \ + --hash=sha256:24a70f416ec0be14b975f160044805ccb48cc6bc50de632983eb8f0a8e16682b \ + --hash=sha256:25e4b79fcc8f849d72db1acc4f03f37008c5c0b745df63d8a30cd35676b6545e \ + --hash=sha256:33bda4753ad0acd2b524c9158089d43486cd44cc59fe970466435bc2968fde2d \ + --hash=sha256:46d4b42a58492dec79caad2d562344689a4f99a828aeea811a0cd2cd653553ef \ + --hash=sha256:4b9b92ac29635f555ef341347d9a63dbf02b7d946347239af3c09e364bc45cf8 \ + --hash=sha256:5742b72a514afe5d60f41330200cd508376e16c650f6962e62337aa482d6a0c6 \ + --hash=sha256:5a6f57126eac9dd3286289e07e91e87b054792f9698b6f7ccab88b624816b542 \ + --hash=sha256:622e6bcdb78d98040d87bea94e65d0bb6ccc0ae1b43294c6bd69f542bf28e092 \ + --hash=sha256:649ed9442dc2d39135c593b6cf0c38e8355170b92672365ab7a3cbc958c42634 \ + --hash=sha256:6ca2b9ce45ad360cbe2996982fb22691ecfe6553ec8f97a2548295f0f96aac78 \ + --hash=sha256:8de799f3b0896f48d306d5e4a04fc6037a08c495d45f9c79935344e5693e3cf8 \ + --hash=sha256:a42ed3b640f4b599a3fc8067c83ee60497c0f03d070d7a7df02a388fa17a546b \ + --hash=sha256:a45de103173c2ed6a0defd7a2919a2bbe531fd5bf6619860cd111ca4a16e9288 \ + --hash=sha256:a7a6d830d9dc5ae8bb156fcde9a1adab7f4edb004f03918a724d885eceb8264d \ + --hash=sha256:b37a0fea4094f95d5926b1d7245abd70deb62882da3d738f9f9b76214894745c \ + --hash=sha256:b427dead5f8ad96d494d3a006d92ea2f8f16be5e6303b590e12234b37f96fbc2 \ + --hash=sha256:b94500fe2d17e491fe2e9bd4a3bf62df217e21a8f2845033c353d4d2ea240f73 \ + --hash=sha256:be45690565907c4aa035d753d82f6ff892d1e6830057b67399542a035b3682f0 \ + --hash=sha256:cfb48c10371c267fdcf7f4ae359cab706f068178b9c65317ead011972f2c0bf3 \ + --hash=sha256:e15fdffa6b60d5729f6025691396b8a01dc3461ba19dc92bba354ec1813ed6b1 \ + --hash=sha256:e6543fb3450a71862cfa1e7a666025d751f81685602cc87d499072ccd839507d \ + --hash=sha256:ea9d7739ae8f6db48b226bbc2a592640f7f2b6d854ff73d0305774b98fa9fb11 \ + --hash=sha256:f1cf33d260316f92f77558185f1c36fc35506d76ee7fdfed9f5b70f9c4bdba7f \ + --hash=sha256:f820950bc44d7b000c223342f5c800c9c08e7fd89524201125388ea211caad1a + # via feast (setup.py) redis==4.6.0 \ --hash=sha256:585dc516b9eb042a619ef0a39c3d7d55fe81bdb4df09a52c9cdde0d07bf1aa7d \ --hash=sha256:e2b03db868160ee4591de3cb90d40ebb50a90dd302138775937f6a42b7ed183c @@ -4202,6 +4295,7 @@ requests==2.32.4 \ # moto # msal # python-keycloak + # ray # requests-oauthlib # requests-toolbelt # responses diff --git a/sdk/python/requirements/py3.12-ci-requirements.txt b/sdk/python/requirements/py3.12-ci-requirements.txt index c1dcdeadf07..5f6cb40390f 100644 --- a/sdk/python/requirements/py3.12-ci-requirements.txt +++ b/sdk/python/requirements/py3.12-ci-requirements.txt @@ -496,6 +496,7 @@ click==8.2.1 \ # geomet # great-expectations # pip-tools + # ray # typer # uvicorn clickhouse-connect==0.8.18 \ @@ -1049,6 +1050,7 @@ filelock==3.18.0 \ # via # datasets # huggingface-hub + # ray # snowflake-connector-python # torch # transformers @@ -1639,9 +1641,9 @@ httpx-sse==0.4.1 \ --hash=sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e \ --hash=sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37 # via mcp -huggingface-hub==0.34.3 \ - --hash=sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492 \ - --hash=sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853 +huggingface-hub==0.34.4 \ + --hash=sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a \ + --hash=sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c # via # accelerate # datasets @@ -1780,6 +1782,7 @@ jsonschema[format-nongpl]==4.25.0 \ # jupyterlab-server # mcp # nbformat + # ray jsonschema-specifications==2025.4.1 \ --hash=sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af \ --hash=sha256:630159c9f4dbea161a6a2205c3011cc4f18ff381b189fff48bb39b9bf26ae608 @@ -2312,6 +2315,67 @@ msal-extensions==1.3.1 \ --hash=sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca \ --hash=sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4 # via azure-identity +msgpack==1.1.1 \ + --hash=sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8 \ + --hash=sha256:1abfc6e949b352dadf4bce0eb78023212ec5ac42f6abfd469ce91d783c149c2a \ + --hash=sha256:1b13fe0fb4aac1aa5320cd693b297fe6fdef0e7bea5518cbc2dd5299f873ae90 \ + --hash=sha256:1d75f3807a9900a7d575d8d6674a3a47e9f227e8716256f35bc6f03fc597ffbf \ + --hash=sha256:2fbbc0b906a24038c9958a1ba7ae0918ad35b06cb449d398b76a7d08470b0ed9 \ + --hash=sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157 \ + --hash=sha256:353b6fc0c36fde68b661a12949d7d49f8f51ff5fa019c1e47c87c4ff34b080ed \ + --hash=sha256:36043272c6aede309d29d56851f8841ba907a1a3d04435e43e8a19928e243c1d \ + --hash=sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0 \ + --hash=sha256:3a89cd8c087ea67e64844287ea52888239cbd2940884eafd2dcd25754fb72232 \ + --hash=sha256:40eae974c873b2992fd36424a5d9407f93e97656d999f43fca9d29f820899084 \ + --hash=sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5 \ + --hash=sha256:435807eeb1bc791ceb3247d13c79868deb22184e1fc4224808750f0d7d1affc1 \ + --hash=sha256:4835d17af722609a45e16037bb1d4d78b7bdf19d6c0128116d178956618c4e88 \ + --hash=sha256:4a28e8072ae9779f20427af07f53bbb8b4aa81151054e882aee333b158da8752 \ + --hash=sha256:4d3237b224b930d58e9d83c81c0dba7aacc20fcc2f89c1e5423aa0529a4cd142 \ + --hash=sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac \ + --hash=sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef \ + --hash=sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323 \ + --hash=sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4 \ + --hash=sha256:5fd1b58e1431008a57247d6e7cc4faa41c3607e8e7d4aaf81f7c29ea013cb458 \ + --hash=sha256:61abccf9de335d9efd149e2fff97ed5974f2481b3353772e8e2dd3402ba2bd57 \ + --hash=sha256:61e35a55a546a1690d9d09effaa436c25ae6130573b6ee9829c37ef0f18d5e78 \ + --hash=sha256:6640fd979ca9a212e4bcdf6eb74051ade2c690b862b679bfcb60ae46e6dc4bfd \ + --hash=sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69 \ + --hash=sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce \ + --hash=sha256:71ef05c1726884e44f8b1d1773604ab5d4d17729d8491403a705e649116c9558 \ + --hash=sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd \ + --hash=sha256:78426096939c2c7482bf31ef15ca219a9e24460289c00dd0b94411040bb73ad2 \ + --hash=sha256:79c408fcf76a958491b4e3b103d1c417044544b68e96d06432a189b43d1215c8 \ + --hash=sha256:7a17ac1ea6ec3c7687d70201cfda3b1e8061466f28f686c24f627cae4ea8efd0 \ + --hash=sha256:7da8831f9a0fdb526621ba09a281fadc58ea12701bc709e7b8cbc362feabc295 \ + --hash=sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c \ + --hash=sha256:88d1e966c9235c1d4e2afac21ca83933ba59537e2e2727a999bf3f515ca2af26 \ + --hash=sha256:88daaf7d146e48ec71212ce21109b66e06a98e5e44dca47d853cbfe171d6c8d2 \ + --hash=sha256:8a8b10fdb84a43e50d38057b06901ec9da52baac6983d3f709d8507f3889d43f \ + --hash=sha256:8b17ba27727a36cb73aabacaa44b13090feb88a01d012c0f4be70c00f75048b4 \ + --hash=sha256:8b65b53204fe1bd037c40c4148d00ef918eb2108d24c9aaa20bc31f9810ce0a8 \ + --hash=sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9 \ + --hash=sha256:96decdfc4adcbc087f5ea7ebdcfd3dee9a13358cae6e81d54be962efc38f6338 \ + --hash=sha256:996f2609ddf0142daba4cefd767d6db26958aac8439ee41db9cc0db9f4c4c3a6 \ + --hash=sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a \ + --hash=sha256:a32747b1b39c3ac27d0670122b57e6e57f28eefb725e0b625618d1b59bf9d1e0 \ + --hash=sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a \ + --hash=sha256:a8ef6e342c137888ebbfb233e02b8fbd689bb5b5fcc59b34711ac47ebd504478 \ + --hash=sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238 \ + --hash=sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7 \ + --hash=sha256:b8f93dcddb243159c9e4109c9750ba5b335ab8d48d9522c5308cd05d7e3ce600 \ + --hash=sha256:ba0c325c3f485dc54ec298d8b024e134acf07c10d494ffa24373bea729acf704 \ + --hash=sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a \ + --hash=sha256:bba1be28247e68994355e028dcd668316db30c1f758d3241a7b903ac78dcd285 \ + --hash=sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c \ + --hash=sha256:d182dac0221eb8faef2e6f44701812b467c02674a322c739355c39e94730cdbf \ + --hash=sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b \ + --hash=sha256:d8b55ea20dc59b181d3f47103f113e6f28a5e1c89fd5b67b9140edb442ab67f2 \ + --hash=sha256:da8f41e602574ece93dbbda1fab24650d6bf2a24089f9e9dbb4f5730ec1e58ad \ + --hash=sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b \ + --hash=sha256:f5be6b6bc52fad84d010cb45433720327ce886009d862f46b26d4d154001994b \ + --hash=sha256:f6d58656842e1b2ddbe07f43f56b10a60f2ba5826164910968f5933e5178af75 + # via ray multidict==6.6.3 \ --hash=sha256:02fd8f32d403a6ff13864b0851f1f523d4c988051eea0471d4f1fd8010f11134 \ --hash=sha256:04cbcce84f63b9af41bad04a54d4cc4e60e90c35b9e6ccb130be2d75b71f8c17 \ @@ -2692,6 +2756,7 @@ packaging==24.2 \ # nbconvert # pandas-gbq # pytest + # ray # scikit-image # snowflake-connector-python # sphinx @@ -3094,6 +3159,7 @@ protobuf==4.25.8 \ # proto-plus # pymilvus # qdrant-client + # ray # substrait psutil==5.9.0 \ --hash=sha256:072664401ae6e7c1bfb878c65d7282d4b4391f1bc9a56d5e03b5a490403271b5 \ @@ -3967,6 +4033,7 @@ pyyaml==6.0.2 \ # jupyter-events # kubernetes # pre-commit + # ray # responses # transformers # uvicorn @@ -4071,6 +4138,32 @@ qdrant-client==1.15.1 \ --hash=sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63 \ --hash=sha256:631f1f3caebfad0fd0c1fba98f41be81d9962b7bf3ca653bed3b727c0e0cbe0e # via feast (setup.py) +ray==2.48.0 \ + --hash=sha256:24a70f416ec0be14b975f160044805ccb48cc6bc50de632983eb8f0a8e16682b \ + --hash=sha256:25e4b79fcc8f849d72db1acc4f03f37008c5c0b745df63d8a30cd35676b6545e \ + --hash=sha256:33bda4753ad0acd2b524c9158089d43486cd44cc59fe970466435bc2968fde2d \ + --hash=sha256:46d4b42a58492dec79caad2d562344689a4f99a828aeea811a0cd2cd653553ef \ + --hash=sha256:4b9b92ac29635f555ef341347d9a63dbf02b7d946347239af3c09e364bc45cf8 \ + --hash=sha256:5742b72a514afe5d60f41330200cd508376e16c650f6962e62337aa482d6a0c6 \ + --hash=sha256:5a6f57126eac9dd3286289e07e91e87b054792f9698b6f7ccab88b624816b542 \ + --hash=sha256:622e6bcdb78d98040d87bea94e65d0bb6ccc0ae1b43294c6bd69f542bf28e092 \ + --hash=sha256:649ed9442dc2d39135c593b6cf0c38e8355170b92672365ab7a3cbc958c42634 \ + --hash=sha256:6ca2b9ce45ad360cbe2996982fb22691ecfe6553ec8f97a2548295f0f96aac78 \ + --hash=sha256:8de799f3b0896f48d306d5e4a04fc6037a08c495d45f9c79935344e5693e3cf8 \ + --hash=sha256:a42ed3b640f4b599a3fc8067c83ee60497c0f03d070d7a7df02a388fa17a546b \ + --hash=sha256:a45de103173c2ed6a0defd7a2919a2bbe531fd5bf6619860cd111ca4a16e9288 \ + --hash=sha256:a7a6d830d9dc5ae8bb156fcde9a1adab7f4edb004f03918a724d885eceb8264d \ + --hash=sha256:b37a0fea4094f95d5926b1d7245abd70deb62882da3d738f9f9b76214894745c \ + --hash=sha256:b427dead5f8ad96d494d3a006d92ea2f8f16be5e6303b590e12234b37f96fbc2 \ + --hash=sha256:b94500fe2d17e491fe2e9bd4a3bf62df217e21a8f2845033c353d4d2ea240f73 \ + --hash=sha256:be45690565907c4aa035d753d82f6ff892d1e6830057b67399542a035b3682f0 \ + --hash=sha256:cfb48c10371c267fdcf7f4ae359cab706f068178b9c65317ead011972f2c0bf3 \ + --hash=sha256:e15fdffa6b60d5729f6025691396b8a01dc3461ba19dc92bba354ec1813ed6b1 \ + --hash=sha256:e6543fb3450a71862cfa1e7a666025d751f81685602cc87d499072ccd839507d \ + --hash=sha256:ea9d7739ae8f6db48b226bbc2a592640f7f2b6d854ff73d0305774b98fa9fb11 \ + --hash=sha256:f1cf33d260316f92f77558185f1c36fc35506d76ee7fdfed9f5b70f9c4bdba7f \ + --hash=sha256:f820950bc44d7b000c223342f5c800c9c08e7fd89524201125388ea211caad1a + # via feast (setup.py) redis==4.6.0 \ --hash=sha256:585dc516b9eb042a619ef0a39c3d7d55fe81bdb4df09a52c9cdde0d07bf1aa7d \ --hash=sha256:e2b03db868160ee4591de3cb90d40ebb50a90dd302138775937f6a42b7ed183c @@ -4194,6 +4287,7 @@ requests==2.32.4 \ # moto # msal # python-keycloak + # ray # requests-oauthlib # requests-toolbelt # responses diff --git a/setup.py b/setup.py index 033b2491c02..2dcaea178cb 100644 --- a/setup.py +++ b/setup.py @@ -180,6 +180,8 @@ "datasets>=3.6.0", ] +RAY_REQUIRED = ["ray>=2.47.0"] + CI_REQUIRED = ( [ "build", @@ -256,6 +258,7 @@ + CLICKHOUSE_REQUIRED + MCP_REQUIRED + RAG_REQUIRED + + RAY_REQUIRED ) MINIMAL_REQUIRED = ( GCP_REQUIRED @@ -276,6 +279,7 @@ + MILVUS_REQUIRED + TORCH_REQUIRED + RAG_REQUIRED + + RAY_REQUIRED ) DOCS_REQUIRED = CI_REQUIRED DEV_REQUIRED = CI_REQUIRED @@ -358,6 +362,7 @@ "clickhouse": CLICKHOUSE_REQUIRED, "mcp": MCP_REQUIRED, "rag": RAG_REQUIRED, + "ray": RAY_REQUIRED, }, include_package_data=True, license="Apache", From bd27d23008fd2aa4830c415cb47be3ed2c94c7ec Mon Sep 17 00:00:00 2001 From: ntkathole Date: Tue, 17 Jun 2025 21:29:56 +0530 Subject: [PATCH 02/10] feat: Added Ray offline store Signed-off-by: ntkathole --- Makefile | 22 + docs/reference/offline-stores/ray.md | 227 +++++++ .../contrib/ray_offline_store/__init__.py | 0 .../contrib/ray_offline_store/ray.py | 635 ++++++++++++++++++ .../ray_offline_store/tests/__init__.py | 0 .../tests/test_ray_integration.py | 146 ++++ .../contrib/ray_repo_configuration.py | 113 ++++ sdk/python/feast/repo_config.py | 1 + .../feature_repos/repo_configuration.py | 4 + 9 files changed, 1148 insertions(+) create mode 100644 docs/reference/offline-stores/ray.md create mode 100644 sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/__init__.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py create mode 100644 sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py diff --git a/Makefile b/Makefile index b8a34855fbf..ee12d9bba26 100644 --- a/Makefile +++ b/Makefile @@ -302,6 +302,28 @@ test-python-universal-postgres-offline: ## Run Python Postgres integration tests not test_spark" \ sdk/python/tests +test-python-universal-ray-offline: ## Run Python Ray offline store integration tests + PYTHONPATH='.' \ + FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.ray_repo_configuration \ + PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.ray_offline_store.tests \ + python -m pytest -n 8 --integration \ + -k "not test_historical_retrieval_with_validation and \ + not test_historical_features_persisting and \ + not test_universal_cli and \ + not test_go_feature_server and \ + not test_feature_logging and \ + not test_reorder_columns and \ + not test_logged_features_validation and \ + not test_lambda_materialization_consistency and \ + not test_offline_write and \ + not test_push_features_to_offline_store and \ + not gcs_registry and \ + not s3_registry and \ + not test_snowflake and \ + not test_spark and \ + not test_trino" \ + sdk/python/tests + test-python-universal-postgres-online: ## Run Python Postgres integration tests PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.postgres_online_store.postgres_repo_configuration \ diff --git a/docs/reference/offline-stores/ray.md b/docs/reference/offline-stores/ray.md new file mode 100644 index 00000000000..bc69dcc7900 --- /dev/null +++ b/docs/reference/offline-stores/ray.md @@ -0,0 +1,227 @@ +# Ray Offline Store (contrib) + +The Ray offline store is a distributed offline store implementation that leverages [Ray](https://www.ray.io/) for distributed data processing. It's particularly useful for large-scale feature engineering and retrieval operations. + +## Overview + +The Ray offline store provides: +- Distributed data processing using Ray +- Support for both local and cluster modes +- Efficient data loading and processing +- Integration with various storage backends (local files, S3, etc.) +- Support for scalable batch materialization +- Saved dataset persistence for data analysis and model training + +## Configuration + +The Ray offline store can be configured in your `feature_store.yaml` file: + +```yaml +project: my_project +registry: data/registry.db +provider: local +offline_store: + type: ray + storage_path: data/ray_storage # Optional: Path for materialized data + ray_address: localhost:10001 # Optional: Ray cluster address + use_ray_cluster: false # Optional: Whether to use Ray cluster +``` + +### Configuration Options + +| Option | Type | Required | Description | +|--------|------|----------|-------------| +| `type` | string | Yes | Must be `feast.offline_stores.ray.RayOfflineStore` | +| `storage_path` | string | No | Path for storing materialized data (e.g., "s3://my-bucket/data") | +| `ray_address` | string | No | Address of the Ray cluster (e.g., "localhost:10001") | +| `use_ray_cluster` | boolean | No | Whether to use Ray cluster mode (default: false) | + +## Usage Examples + +### Basic Usage + +```python +from feast import FeatureStore, FeatureView, FileSource +from feast.types import Float32, Int64 +from datetime import timedelta + +# Define a feature view +driver_stats = FeatureView( + name="driver_stats", + entities=["driver_id"], + ttl=timedelta(days=1), + online=True, + source=FileSource( + path="data/driver_stats.parquet", + timestamp_field="event_timestamp", + ), + schema=[ + ("driver_id", Int64), + ("avg_daily_trips", Float32), + ], +) + +# Initialize feature store +store = FeatureStore("feature_store.yaml") + +# Get historical features +entity_df = pd.DataFrame({ + "driver_id": [1, 2, 3], + "event_timestamp": [datetime.now()] * 3 +}) + +features = store.get_historical_features( + entity_df=entity_df, + features=[ + "driver_stats:avg_daily_trips" + ] +).to_df() +``` + +### Saved Dataset Persistence + +The Ray offline store supports persisting datasets for later analysis and model training: + +```python +from feast import FeatureStore +from feast.infra.offline_stores.file_source import SavedDatasetFileStorage + +# Initialize feature store +store = FeatureStore("feature_store.yaml") + +# Get historical features +entity_df = pd.DataFrame({ + "driver_id": [1, 2, 3, 4, 5], + "event_timestamp": [datetime.now()] * 5 +}) + +# Create a retrieval job +job = store.get_historical_features( + entity_df=entity_df, + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:total_trips" + ] +) + +# Create storage destination +storage = SavedDatasetFileStorage(path="data/training_dataset.parquet") + +# Persist the dataset +job.persist(storage, allow_overwrite=False) + +# Create a saved dataset in the registry +saved_dataset = store.create_saved_dataset( + from_=job, + name="driver_training_dataset", + storage=storage, + tags={"purpose": "model_training", "version": "v1"} +) + +print(f"Saved dataset created: {saved_dataset.name}") +``` + +### Remote Storage Persistence + +You can persist datasets to remote storage for distributed access: + +```python +# Persist to S3 +s3_storage = SavedDatasetFileStorage(path="s3://my-bucket/datasets/driver_features.parquet") +job.persist(s3_storage, allow_overwrite=True) + +# Persist to Google Cloud Storage +gcs_storage = SavedDatasetFileStorage(path="gs://my-project-bucket/datasets/driver_features.parquet") +job.persist(gcs_storage, allow_overwrite=True) + +# Persist to HDFS +hdfs_storage = SavedDatasetFileStorage(path="hdfs://namenode:8020/datasets/driver_features.parquet") +job.persist(hdfs_storage, allow_overwrite=True) +``` + +### Retrieving Saved Datasets + +You can retrieve previously saved datasets: + +```python +# Retrieve a saved dataset +saved_dataset = store.get_saved_dataset("driver_training_dataset") + +# Convert to different formats +df = saved_dataset.to_df() # Pandas DataFrame +arrow_table = saved_dataset.to_arrow() # PyArrow Table + +# Get dataset metadata +print(f"Dataset features: {saved_dataset.features}") +print(f"Join keys: {saved_dataset.join_keys}") +print(f"Min timestamp: {saved_dataset.min_event_timestamp}") +print(f"Max timestamp: {saved_dataset.max_event_timestamp}") +``` + +### Batch Materialization with Persistence + +Combine batch materialization with dataset persistence: + +```python +from datetime import datetime, timedelta + +# Materialize features for the last 30 days +store.materialize( + start_date=datetime.now() - timedelta(days=30), + end_date=datetime.now(), + feature_views=["driver_stats"] +) + +# Get historical features for the materialized period +entity_df = pd.DataFrame({ + "driver_id": list(range(1, 1001)), # 1000 drivers + "event_timestamp": [datetime.now()] * 1000 +}) + +job = store.get_historical_features( + entity_df=entity_df, + features=["driver_stats:avg_daily_trips"] +) + +# Persist to remote storage for distributed access +remote_storage = SavedDatasetFileStorage( + path="s3://my-bucket/large_datasets/driver_features_30d.parquet" +) +job.persist(remote_storage, allow_overwrite=True) +``` + +### Using Ray Cluster + +To use Ray in cluster mode: + +1. Start a Ray cluster: +```bash +ray start --head --port=10001 +``` + +2. Configure your `feature_store.yaml`: +```yaml +offline_store: + type: ray + ray_address: localhost:10001 + use_ray_cluster: true +``` + +### Remote Storage + +For large-scale materialization, you can use remote storage: + +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/features +``` + +```python +# Materialize features to remote storage +store.materialize( + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), + feature_views=["driver_stats"] +) +``` diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py new file mode 100644 index 00000000000..11ad1341236 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -0,0 +1,635 @@ +import os +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union + +import fsspec +import numpy as np +import pandas as pd +import pyarrow as pa +import ray +import ray.data +from ray.data import Dataset +from ray.data.context import DatasetContext + +from feast.data_source import DataSource +from feast.errors import ( + RequestDataNotFoundInEntityDfException, + SavedDatasetLocationAlreadyExists, +) +from feast.feature_logging import LoggingConfig, LoggingSource +from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView +from feast.infra.offline_stores.file_source import FileSource, SavedDatasetFileStorage +from feast.infra.offline_stores.offline_store import ( + OfflineStore, + RetrievalJob, + RetrievalMetadata, +) +from feast.infra.offline_stores.offline_utils import get_expected_join_keys +from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.saved_dataset import SavedDatasetStorage, ValidationReference +from feast.utils import _get_column_names, _utc_now, make_df_tzaware + + +class RayRetrievalJob(RetrievalJob): + def __init__( + self, + dataset_or_callable: Union[Dataset, Callable[[], Dataset]], + staging_location: Optional[str] = None, + ): + self._dataset_or_callable = dataset_or_callable + self._staging_location = staging_location + self._cached_dataset: Optional[Dataset] = None + self._metadata: Optional[RetrievalMetadata] = None + self._full_feature_names: bool = False + self._on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None + + def _resolve(self) -> Any: + if callable(self._dataset_or_callable): + result = self._dataset_or_callable() + else: + result = self._dataset_or_callable + return result + + def to_df( + self, + validation_reference: Optional[ValidationReference] = None, + timeout: Optional[int] = None, + ) -> pd.DataFrame: + result = self._resolve() + if isinstance(result, pd.DataFrame): + return result + return result.to_pandas() + + def to_arrow( + self, + validation_reference: Optional[ValidationReference] = None, + timeout: Optional[int] = None, + ) -> pa.Table: + result = self._resolve() + if isinstance(result, pd.DataFrame): + return pa.Table.from_pandas(result) + # For Ray Dataset, convert to pandas first then to arrow + return pa.Table.from_pandas(result.to_pandas()) + + def to_remote_storage(self) -> list[str]: + if not self._staging_location: + raise ValueError("Staging location must be set for remote materialization.") + try: + ds = self._resolve() + RayOfflineStore._ensure_ray_initialized() + output_uri = os.path.join(self._staging_location, str(uuid.uuid4())) + ds.write_parquet(output_uri) + return [output_uri] + except Exception as e: + raise RuntimeError(f"Failed to write to remote storage: {e}") + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + """Return metadata information about retrieval.""" + return self._metadata + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return self._on_demand_feature_views or [] + + def to_sql(self) -> str: + raise NotImplementedError("SQL export not supported for Ray offline store") + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + return self._resolve().to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: + result = self._resolve() + if isinstance(result, pd.DataFrame): + return pa.Table.from_pandas(result) + # For Ray Dataset, convert to pandas first then to arrow + return pa.Table.from_pandas(result.to_pandas()) + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: Optional[bool] = False, + timeout: Optional[int] = None, + ) -> str: + """Persist the dataset to storage.""" + + if not isinstance(storage, SavedDatasetFileStorage): + raise ValueError( + f"Ray offline store only supports SavedDatasetFileStorage, got {type(storage)}" + ) + destination_path = storage.file_options.uri + if not destination_path.startswith(("s3://", "gs://", "hdfs://")): + if not allow_overwrite and os.path.exists(destination_path): + raise SavedDatasetLocationAlreadyExists(location=destination_path) + try: + ds = self._resolve() + if not destination_path.startswith(("s3://", "gs://", "hdfs://")): + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + ds.write_parquet(destination_path) + return destination_path + except Exception as e: + raise RuntimeError(f"Failed to persist dataset to {destination_path}: {e}") + + +class RayOfflineStoreConfig(FeastConfigBaseModel): + type: Literal[ + "feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore", "ray" + ] = "ray" + storage_path: Optional[str] = None + ray_address: Optional[str] = None + use_ray_cluster: Optional[bool] = False + + +class RayOfflineStore(OfflineStore): + def __init__(self): + self._staging_location: Optional[str] = None + self._ray_initialized: bool = False + + @staticmethod + def _ensure_ray_initialized(config: Optional[RepoConfig] = None): + """Ensure Ray is initialized with proper configuration.""" + if not ray.is_initialized(): + if config and hasattr(config, "offline_store"): + ray_config = config.offline_store + if isinstance(ray_config, RayOfflineStoreConfig): + if ray_config.use_ray_cluster and ray_config.ray_address: + ray.init( + address=ray_config.ray_address, + ignore_reinit_error=True, + include_dashboard=False, + ) + else: + ray.init( + _node_ip_address=os.getenv("RAY_NODE_IP", "127.0.0.1"), + num_cpus=os.cpu_count() or 4, + ignore_reinit_error=True, + include_dashboard=False, + ) + else: + ray.init(ignore_reinit_error=True) + else: + ray.init(ignore_reinit_error=True) + + ctx = DatasetContext.get_current() + ctx.shuffle_strategy = "sort" + ctx.enable_tensor_extension_casting = False + + def _init_ray(self, config: RepoConfig): + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + + self._ensure_ray_initialized(config) + + def _get_source_path(self, source: DataSource, config: RepoConfig) -> str: + if not isinstance(source, FileSource): + raise ValueError("RayOfflineStore currently only supports FileSource") + repo_path = getattr(config, "repo_path", None) + uri = FileSource.get_uri_for_file_path(repo_path, source.path) + return uri + + @staticmethod + def _create_filtered_dataset( + source_path: str, + timestamp_field: str, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> Dataset: + """Helper method to create a filtered dataset based on timestamp range.""" + ds = ray.data.read_parquet(source_path) + + try: + col_names = ds.schema().names + if timestamp_field not in col_names: + raise ValueError( + f"Timestamp field '{timestamp_field}' not found in columns: {col_names}" + ) + except Exception as e: + raise ValueError(f"Failed to get dataset schema: {e}") + + if start_date or end_date: + try: + if start_date and end_date: + filtered_ds = ds.filter( + lambda row: start_date <= row[timestamp_field] <= end_date + ) + elif start_date: + filtered_ds = ds.filter( + lambda row: row[timestamp_field] >= start_date + ) + elif end_date: + filtered_ds = ds.filter( + lambda row: row[timestamp_field] <= end_date + ) + else: + return ds + + return filtered_ds + except Exception as e: + raise RuntimeError(f"Failed to filter by timestamp: {e}") + + return ds + + @staticmethod + def get_historical_features( + config: RepoConfig, + feature_views: List[FeatureView], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + # Load entity_df + original_entity_df = ( + pd.read_csv(entity_df) if isinstance(entity_df, str) else entity_df.copy() + ) + result_df = make_df_tzaware(original_entity_df.copy()) + if "event_timestamp" in result_df.columns: + result_df["event_timestamp"] = pd.to_datetime( + result_df["event_timestamp"], utc=True, errors="coerce" + ).dt.floor("s") + + # Parse feature_refs and get ODFVs + on_demand_feature_views = OnDemandFeatureView.get_requested_odfvs( + feature_refs, project, registry + ) + + # --- Request Data Validation for ODFVs --- + for odfv in on_demand_feature_views: + odfv_request_data_schema = odfv.get_request_data_schema() + for feature_name in odfv_request_data_schema.keys(): + if feature_name not in original_entity_df.columns: + raise RequestDataNotFoundInEntityDfException( + feature_name=feature_name, + feature_view_name=odfv.name, + ) + + # Collect all join keys from feature views + all_join_keys = get_expected_join_keys(project, feature_views, registry) + if "event_timestamp" in result_df.columns: + all_join_keys.add("event_timestamp") + + # Keep only relevant entity columns and timestamp + result_df = result_df[ + [col for col in result_df.columns if col in all_join_keys] + ] + + requested_feature_columns = [] + added_dummy_columns = set() + + # Join each feature view + for fv in feature_views: + # Only process feature views that are referenced + fv_feature_refs = [ + ref for ref in feature_refs if ref.startswith(fv.name + ":") + ] + if not fv_feature_refs: + continue + + # Get join keys, feature names, timestamp, created timestamp + entities = fv.entities or [] + entity_objs = [registry.get_entity(e, project) for e in entities] + join_keys, feature_names, timestamp_field, created_col = _get_column_names( + fv, entity_objs + ) + if not join_keys: + join_keys = [DUMMY_ENTITY_ID] + + # Only add features that are actually requested in feature_refs + requested_feats = [ref.split(":", 1)[1] for ref in fv_feature_refs] + + # --- Error for Missing Features --- + available_feature_names = [f.name for f in fv.features] + missing_feats = [ + f for f in requested_feats if f not in available_feature_names + ] + if missing_feats: + raise KeyError( + f"Requested features {missing_feats} not found in feature view '{fv.name}' (available: {available_feature_names})" + ) + + for feat in requested_feats: + col_name = f"{fv.name}__{feat}" if full_feature_names else feat + requested_feature_columns.append(col_name) + + # Read feature data + source_path = store._get_source_path(fv.batch_source, config) + if not source_path: + raise ValueError(f"Missing batch source for FV {fv.name}") + feature_ds = ray.data.read_parquet(str(source_path)) + feature_df = feature_ds.to_pandas() + feature_df = make_df_tzaware(feature_df) + if timestamp_field in feature_df.columns: + feature_df[timestamp_field] = pd.to_datetime( + feature_df[timestamp_field], utc=True, errors="coerce" + ).dt.floor("s") + + # Ensure join keys exist in both entity and feature dataframe + for k in join_keys: + if k not in result_df.columns: + result_df[k] = DUMMY_ENTITY_VAL + added_dummy_columns.add(k) + if k not in feature_df.columns: + feature_df[k] = DUMMY_ENTITY_VAL + + if ( + timestamp_field not in result_df.columns + and "event_timestamp" in result_df.columns + ): + result_df[timestamp_field] = result_df["event_timestamp"] + + # Align join key dtypes before merge + for k in join_keys: + if k in result_df.columns and k in feature_df.columns: + feature_df[k] = feature_df[k].astype(result_df[k].dtype) + + # Deduplicate feature values (avoid list columns in keys) + dedup_keys = join_keys + [timestamp_field] + if created_col and created_col in feature_df.columns: + feature_df = feature_df.sort_values(by=dedup_keys + [created_col]) + feature_df = feature_df.groupby(dedup_keys, as_index=False).last() + else: + feature_df = feature_df.sort_values(by=dedup_keys) + feature_df = feature_df.drop_duplicates(subset=dedup_keys, keep="last") + + # Select only requested features that exist in feature_df + existing_feats = [f for f in requested_feats if f in feature_df.columns] + cols_to_keep = join_keys + [timestamp_field] + existing_feats + feature_df = feature_df[cols_to_keep] + + # Join into result_df + result_df = result_df.merge( + feature_df, + how="inner", + on=join_keys + + ([timestamp_field] if timestamp_field in result_df.columns else []), + ) + + # Handle full feature names + if full_feature_names: + result_df = result_df.rename( + columns={ + f: f"{fv.name}__{f}" + for f in existing_feats + if f in result_df.columns + } + ) + + # Re-attach original entity columns + for col in original_entity_df.columns: + if col not in result_df.columns: + result_df[col] = original_entity_df[col] + + # Ensure event_timestamp is present + if ( + "event_timestamp" not in result_df.columns + and "event_timestamp" in original_entity_df.columns + ): + result_df["event_timestamp"] = pd.to_datetime( + original_entity_df["event_timestamp"], utc=True, errors="coerce" + ).dt.floor("s") + + if ( + "event_timestamp" not in result_df.columns + and timestamp_field in result_df.columns + ): + result_df["event_timestamp"] = result_df[timestamp_field] + + # Drop dummy entity columns + for dummy_col in added_dummy_columns: + if dummy_col in result_df.columns: + result_df = result_df.drop(columns=[dummy_col]) + + # Reorder columns: entity + timestamp + features (in requested order) + entity_columns = [ + c for c in original_entity_df.columns if c != "event_timestamp" + ] + # Build the list of output feature columns in the correct order + output_feature_columns = [] + for ref in feature_refs: + fv_name, feat = ref.split(":", 1) + col_name = f"{fv_name}__{feat}" if full_feature_names else feat + output_feature_columns.append(col_name) + + # Ensure all requested features are present, fill with NaN if missing + for col in output_feature_columns: + if col not in result_df.columns: + result_df[col] = np.nan + + final_columns = entity_columns + ["event_timestamp"] + output_feature_columns + result_df = result_df.reindex(columns=final_columns) + + # Convert list/numpy.ndarray columns to tuples for deduplication + def make_hashable_for_dedup(df, columns): + for col in columns: + if col in df.columns: + if df[col].apply(lambda x: isinstance(x, (np.ndarray, list))).any(): + df[col] = df[col].apply( + lambda x: tuple(x) + if isinstance(x, (np.ndarray, list)) + else x + ) + return df + + list_columns = [ + col + for col in final_columns + if col in result_df.columns + and result_df[col].apply(lambda x: isinstance(x, (np.ndarray, list))).any() + ] + result_df = make_hashable_for_dedup(result_df, list_columns) + + # Deduplicate + result_df = result_df.drop_duplicates().reset_index(drop=True) + + # Convert tuple columns back to lists + for col in list_columns: + if col in result_df.columns: + result_df[col] = result_df[col].apply( + lambda x: list(x) if isinstance(x, tuple) else x + ) + + # Return retrieval job + storage_path = config.offline_store.storage_path + if not storage_path: + raise ValueError("Storage path must be set in config") + + job = RayRetrievalJob(result_df, staging_location=storage_path) + job._full_feature_names = full_feature_names + job._on_demand_feature_views = on_demand_feature_views + return job + + def validate_data_source( + self, + config: RepoConfig, + data_source: DataSource, + ): + """Validates the underlying data source.""" + self._init_ray(config) + data_source.validate(config=config) + + def get_table_column_names_and_types_from_data_source( + self, + config: RepoConfig, + data_source: DataSource, + ) -> Iterable[Tuple[str, str]]: + """Returns the list of column names and raw column types for a DataSource.""" + return data_source.get_table_column_names_and_types(config=config) + + def supports_remote_storage_export(self) -> bool: + """Check if remote storage export is supported.""" + return self._staging_location is not None + + @staticmethod + def pull_latest_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: datetime, + end_date: datetime, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + source_path = store._get_source_path(data_source, config) + + def _load(): + try: + return RayOfflineStore._create_filtered_dataset( + source_path, timestamp_field, start_date, end_date + ) + except Exception as e: + raise RuntimeError(f"Failed to load data from {source_path}: {e}") + + return RayRetrievalJob( + _load, staging_location=config.offline_store.storage_path + ) + + @staticmethod + def pull_all_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + source_path = store._get_source_path(data_source, config) + + fs, path_in_fs = fsspec.core.url_to_fs(source_path) + if not fs.exists(path_in_fs): + raise FileNotFoundError(f"Parquet path does not exist: {source_path}") + + def _load(): + try: + return RayOfflineStore._create_filtered_dataset( + source_path, timestamp_field, start_date, end_date + ) + except Exception as e: + raise RuntimeError(f"Failed to load data from {source_path}: {e}") + + return RayRetrievalJob( + _load, staging_location=config.offline_store.storage_path + ) + + @staticmethod + def write_logged_features( + config: RepoConfig, + data: Union[pa.Table, Path], + source: LoggingSource, + logging_config: LoggingConfig, + registry: BaseRegistry, + ) -> None: + RayOfflineStore._ensure_ray_initialized(config) + + repo_path = getattr(config, "repo_path", None) or os.getcwd() + + # Get source path and resolve URI + source_path = getattr(source, "file_path", None) + if not source_path: + raise ValueError("LoggingSource must have a file_path attribute") + + path = FileSource.get_uri_for_file_path(repo_path, source_path) + + try: + if isinstance(data, Path): + ds = ray.data.read_parquet(str(data)) + else: + ds = ray.data.from_pandas(pa.Table.to_pandas(data)) + + ds.materialize() + + if not path.startswith(("s3://", "gs://")): + os.makedirs(os.path.dirname(path), exist_ok=True) + + ds.write_parquet(path) + except Exception as e: + raise RuntimeError(f"Failed to write logged features: {e}") + + @staticmethod + def offline_write_batch( + config: RepoConfig, + feature_view: FeatureView, + table: pa.Table, + progress: Optional[Callable[[int], Any]] = None, + ) -> None: + RayOfflineStore._ensure_ray_initialized(config) + + repo_path = getattr(config, "repo_path", None) or os.getcwd() + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + base_storage_path = ray_config.storage_path or "/tmp/ray-storage" + + batch_source_path = getattr(feature_view.batch_source, "file_path", None) + if not batch_source_path: + batch_source_path = f"{feature_view.name}/push_{_utc_now()}.parquet" + + feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) + storage_path = FileSource.get_uri_for_file_path(repo_path, base_storage_path) + + feature_dir = os.path.dirname(feature_path) + if not feature_dir.startswith(("s3://", "gs://")): + os.makedirs(feature_dir, exist_ok=True) + if not storage_path.startswith(("s3://", "gs://")): + os.makedirs(os.path.dirname(storage_path), exist_ok=True) + + df = table.to_pandas() + ds = ray.data.from_pandas(df) + ds.materialize() + ds.write_parquet(feature_dir) + + @staticmethod + def create_saved_dataset_destination( + config: RepoConfig, + name: str, + path: Optional[str] = None, + ) -> SavedDatasetStorage: + """Create a saved dataset destination for Ray offline store.""" + + if path is None: + # Use default path based on config + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + base_storage_path = ray_config.storage_path or "/tmp/ray-storage" + path = f"{base_storage_path}/saved_datasets/{name}.parquet" + + return SavedDatasetFileStorage(path=path) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py new file mode 100644 index 00000000000..5ab82f8ef47 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py @@ -0,0 +1,146 @@ +import pandas as pd +import pytest + +from feast.utils import _utc_now +from tests.integration.feature_repos.repo_configuration import ( + construct_universal_feature_views, +) +from tests.integration.feature_repos.universal.entities import driver + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_ray_offline_store_basic_write_and_read(environment, universal_data_sources): + """Test basic write and read functionality with Ray offline store.""" + store = environment.feature_store + _, _, data_sources = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + driver_fv = feature_views.driver + store.apply([driver(), driver_fv]) + + now = _utc_now() + ts = pd.Timestamp(now).round("ms") + + # Write data to offline store + df_to_write = pd.DataFrame.from_dict( + { + "event_timestamp": [ts, ts], + "driver_id": [1001, 1002], + "conv_rate": [0.1, 0.2], + "acc_rate": [0.9, 0.8], + "avg_daily_trips": [10, 20], + "created": [ts, ts], + }, + ) + + store.write_to_offline_store( + driver_fv.name, df_to_write, allow_registry_cache=False + ) + + # Read data back + entity_df = pd.DataFrame({"driver_id": [1001, 1002], "event_timestamp": [ts, ts]}) + + result_df = store.get_historical_features( + entity_df=entity_df, + features=[ + "driver_stats:conv_rate", + "driver_stats:acc_rate", + "driver_stats:avg_daily_trips", + ], + full_feature_names=False, + ).to_df() + + assert len(result_df) == 2 + assert "conv_rate" in result_df.columns + assert "acc_rate" in result_df.columns + assert "avg_daily_trips" in result_df.columns + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: f"full:{v}") +def test_ray_offline_store_historical_features( + environment, universal_data_sources, full_feature_names +): + """Test historical features retrieval with Ray offline store.""" + store = environment.feature_store + + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + + entity_df_with_request_data = datasets.entity_df.copy(deep=True) + entity_df_with_request_data["val_to_add"] = [ + i for i in range(len(entity_df_with_request_data)) + ] + + store.apply( + [ + driver(), + *feature_views.values(), + ] + ) + + job = store.get_historical_features( + entity_df=entity_df_with_request_data, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "conv_rate_plus_100:conv_rate_plus_100", + ], + full_feature_names=full_feature_names, + ) + + # Test DataFrame conversion + result_df = job.to_df() + assert len(result_df) > 0 + assert "event_timestamp" in result_df.columns + + # Test Arrow conversion + result_table = job.to_arrow().to_pandas() + assert len(result_table) > 0 + assert "event_timestamp" in result_table.columns + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_ray_offline_store_persist(environment, universal_data_sources): + """Test dataset persistence with Ray offline store.""" + store = environment.feature_store + + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + + entity_df_with_request_data = datasets.entity_df.copy(deep=True) + entity_df_with_request_data["val_to_add"] = [ + i for i in range(len(entity_df_with_request_data)) + ] + + store.apply( + [ + driver(), + *feature_views.values(), + ] + ) + + job = store.get_historical_features( + entity_df=entity_df_with_request_data, + features=[ + "driver_stats:conv_rate", + "customer_profile:current_balance", + ], + full_feature_names=False, + ) + + # Test persisting the dataset + from feast.saved_dataset import SavedDatasetFileStorage + + storage = SavedDatasetFileStorage(path="data/test_saved_dataset.parquet") + saved_path = job.persist(storage, allow_overwrite=True) + + assert saved_path == "data/test_saved_dataset.parquet" + + # Verify the saved dataset exists + import os + + assert os.path.exists(saved_path) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py new file mode 100644 index 00000000000..ea6cdbaa3c8 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py @@ -0,0 +1,113 @@ +import os +import tempfile +from typing import Any, Dict, Optional + +from sdk.python.feast.infra.offline_stores.contrib.ray_offline_store.ray import ( + RayOfflineStoreConfig, +) + +from feast.data_format import ParquetFormat +from feast.data_source import DataSource +from feast.feature_logging import LoggingDestination +from feast.infra.offline_stores.file_source import ( + FileLoggingDestination, + FileSource, + SavedDatasetFileStorage, +) +from feast.repo_config import FeastConfigBaseModel +from feast.saved_dataset import SavedDatasetStorage +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.data_source_creator import ( + DataSourceCreator, +) + + +class RayDataSourceCreator(DataSourceCreator): + def __init__(self, project_name: str, *args, **kwargs): + super().__init__(project_name, *args, **kwargs) + self.offline_store_config = RayOfflineStoreConfig( + type="ray", + storage_path="/tmp/ray-storage", + ray_address=None, + use_ray_cluster=False, + ) + self.files = [] + self.dirs = [] + + def create_offline_store_config(self) -> FeastConfigBaseModel: + return self.offline_store_config + + def create_data_source( + self, + df: Any, + destination_name: str, + created_timestamp_column: Optional[Any] = "created_ts", + field_mapping: Optional[Dict[str, str]] = None, + timestamp_field: Optional[str] = "ts", + ) -> DataSource: + # For Ray, we'll use parquet files as the underlying storage + destination_name = self.get_prefixed_table_name(destination_name) + + f = tempfile.NamedTemporaryFile( + prefix=f"{self.project_name}_{destination_name}", + suffix=".parquet", + delete=False, + ) + df.to_parquet(f.name) + self.files.append(f) + + return FileSource( + file_format=ParquetFormat(), + path=f.name, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + field_mapping=field_mapping or {"ts_1": "ts"}, + ) + + def get_prefixed_table_name(self, suffix: str) -> str: + return f"{self.project_name}.{suffix}" + + def create_saved_dataset_destination(self) -> SavedDatasetStorage: + d = tempfile.mkdtemp(prefix=self.project_name) + self.dirs.append(d) + return SavedDatasetFileStorage( + path=d, + file_format=ParquetFormat(), + ) + + def create_logged_features_destination(self) -> LoggingDestination: + d = tempfile.mkdtemp(prefix=self.project_name) + self.dirs.append(d) + return FileLoggingDestination(path=d) + + def teardown(self) -> None: + # Clean up any temporary files or resources + import shutil + + for f in self.files: + f.close() + try: + os.unlink(f.name) + except OSError: + pass + + for d in self.dirs: + if os.path.exists(d): + shutil.rmtree(d) + + def get_saved_dataset_data_source(self) -> Dict[str, str]: + return { + "type": "parquet", + "path": "data/saved_dataset.parquet", + } + + +# Define the full repo configurations for Ray offline store +FULL_REPO_CONFIGS = [ + IntegrationTestRepoConfig( + provider="local", + offline_store_creator=RayDataSourceCreator, + ), +] diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 545d5ba4c3a..24c9d30a028 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -101,6 +101,7 @@ "remote": "feast.infra.offline_stores.remote.RemoteOfflineStore", "couchbase.offline": "feast.infra.offline_stores.contrib.couchbase_offline_store.couchbase.CouchbaseColumnarOfflineStore", "clickhouse": "feast.infra.offline_stores.contrib.clickhouse_offline_store.clickhouse.ClickhouseOfflineStore", + "ray": "feast.infra.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore", } FEATURE_SERVER_CONFIG_CLASS_FOR_TYPE = { diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 24e611c4f33..89a13df69ed 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -27,6 +27,9 @@ FeatureLoggingConfig, ) from feast.infra.feature_servers.local_process.config import LocalFeatureServerConfig +from feast.infra.offline_stores.contrib.ray_repo_configuration import ( + RayDataSourceCreator, +) from feast.permissions.action import AuthzedAction from feast.permissions.auth_model import OidcClientAuthConfig from feast.permissions.permission import Permission @@ -137,6 +140,7 @@ ("local", RemoteOfflineStoreDataSourceCreator), ("local", RemoteOfflineOidcAuthStoreDataSourceCreator), ("local", RemoteOfflineTlsStoreDataSourceCreator), + ("local", RayDataSourceCreator), ] if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": From 51c3426af557fd0264e210fd8acf757b2c5af4cb Mon Sep 17 00:00:00 2001 From: ntkathole Date: Fri, 27 Jun 2025 17:24:09 +0530 Subject: [PATCH 03/10] feat: Improved Ray distributed processing Signed-off-by: ntkathole --- Makefile | 8 +- docs/reference/offline-stores/ray.md | 111 +- .../contrib/ray_offline_store/__init__.py | 56 + .../contrib/ray_offline_store/ray.py | 1262 ++++++++++++++--- .../contrib/ray_repo_configuration.py | 11 +- 5 files changed, 1245 insertions(+), 203 deletions(-) diff --git a/Makefile b/Makefile index ee12d9bba26..e088de84e99 100644 --- a/Makefile +++ b/Makefile @@ -307,21 +307,17 @@ test-python-universal-ray-offline: ## Run Python Ray offline store integration t FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.ray_repo_configuration \ PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.ray_offline_store.tests \ python -m pytest -n 8 --integration \ + -m "not universal_online_stores and not benchmark" \ -k "not test_historical_retrieval_with_validation and \ - not test_historical_features_persisting and \ not test_universal_cli and \ not test_go_feature_server and \ not test_feature_logging and \ - not test_reorder_columns and \ not test_logged_features_validation and \ not test_lambda_materialization_consistency and \ - not test_offline_write and \ - not test_push_features_to_offline_store and \ not gcs_registry and \ not s3_registry and \ not test_snowflake and \ - not test_spark and \ - not test_trino" \ + not test_spark" \ sdk/python/tests test-python-universal-postgres-online: ## Run Python Postgres integration tests diff --git a/docs/reference/offline-stores/ray.md b/docs/reference/offline-stores/ray.md index bc69dcc7900..b0ba1e145c8 100644 --- a/docs/reference/offline-stores/ray.md +++ b/docs/reference/offline-stores/ray.md @@ -7,11 +7,28 @@ The Ray offline store is a distributed offline store implementation that leverag The Ray offline store provides: - Distributed data processing using Ray - Support for both local and cluster modes -- Efficient data loading and processing - Integration with various storage backends (local files, S3, etc.) - Support for scalable batch materialization - Saved dataset persistence for data analysis and model training +## Optimization Features + +### Intelligent Join Strategies + +The Ray offline store now includes intelligent join strategy selection: + +- **Broadcast Joins**: For small feature datasets (<100MB by default), data is stored in Ray's object store for efficient broadcasting +- **Distributed Windowed Joins**: For large datasets, uses time-based windowing for distributed point-in-time joins +- **Automatic Strategy Selection**: Chooses optimal join strategy based on dataset size and cluster resources + +### Resource Management + +The store automatically detects and optimizes for your Ray cluster: + +- **Auto-scaling**: Adjusts parallelism based on available CPU cores +- **Memory Optimization**: Configures buffer sizes based on available memory +- **Partition Optimization**: Calculates optimal partition sizes for your workload + ## Configuration The Ray offline store can be configured in your `feature_store.yaml` file: @@ -22,19 +39,30 @@ registry: data/registry.db provider: local offline_store: type: ray - storage_path: data/ray_storage # Optional: Path for materialized data - ray_address: localhost:10001 # Optional: Ray cluster address - use_ray_cluster: false # Optional: Whether to use Ray cluster + storage_path: data/ray_storage # Optional: Path for materialized data + ray_address: localhost:10001 # Optional: Ray cluster address + use_ray_cluster: false # Optional: Whether to use Ray cluster + # New optimization settings + broadcast_join_threshold_mb: 100 # Optional: Threshold for broadcast joins (MB) + enable_distributed_joins: true # Optional: Enable distributed join strategies + max_parallelism_multiplier: 2 # Optional: Max parallelism as multiple of CPU cores + target_partition_size_mb: 64 # Optional: Target partition size (MB) + window_size_for_joins: "1H" # Optional: Time window size for distributed joins ``` ### Configuration Options -| Option | Type | Required | Description | -|--------|------|----------|-------------| -| `type` | string | Yes | Must be `feast.offline_stores.ray.RayOfflineStore` | -| `storage_path` | string | No | Path for storing materialized data (e.g., "s3://my-bucket/data") | -| `ray_address` | string | No | Address of the Ray cluster (e.g., "localhost:10001") | -| `use_ray_cluster` | boolean | No | Whether to use Ray cluster mode (default: false) | +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `type` | string | Required | Must be `feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore` or `ray` | +| `storage_path` | string | None | Path for storing materialized data (e.g., "s3://my-bucket/data") | +| `ray_address` | string | None | Address of the Ray cluster (e.g., "localhost:10001") | +| `use_ray_cluster` | boolean | false | Whether to use Ray cluster mode | +| `broadcast_join_threshold_mb` | int | 100 | Size threshold (MB) below which broadcast joins are used | +| `enable_distributed_joins` | boolean | true | Enable intelligent distributed join strategies | +| `max_parallelism_multiplier` | int | 2 | Maximum parallelism as multiple of CPU cores | +| `target_partition_size_mb` | int | 64 | Target size for data partitions (MB) | +| `window_size_for_joins` | string | "1H" | Time window size for distributed temporal joins | ## Usage Examples @@ -78,6 +106,43 @@ features = store.get_historical_features( ).to_df() ``` +### Optimized Configuration for Large Datasets + +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data + use_ray_cluster: true + ray_address: ray://head-node:10001 + # Optimize for large datasets + broadcast_join_threshold_mb: 50 # Smaller threshold for large clusters + max_parallelism_multiplier: 4 # Higher parallelism for more CPUs + target_partition_size_mb: 128 # Larger partitions for better throughput + window_size_for_joins: "30min" # Smaller windows for better distribution +``` + +### High-Performance Feature Retrieval + +```python +# For large-scale feature retrieval with millions of entities +large_entity_df = pd.DataFrame({ + "driver_id": range(1, 1000000), # 1M drivers + "event_timestamp": [datetime.now()] * 1000000 +}) + +# The Ray offline store will automatically: +# 1. Detect large dataset and use distributed joins +# 2. Partition data optimally across cluster +# 3. Use appropriate join strategy based on feature data size +features = store.get_historical_features( + entity_df=large_entity_df, + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:total_distance" + ] +).to_df() +``` + ### Saved Dataset Persistence The Ray offline store supports persisting datasets for later analysis and model training: @@ -192,7 +257,7 @@ job.persist(remote_storage, allow_overwrite=True) ### Using Ray Cluster -To use Ray in cluster mode: +To use Ray in cluster mode for maximum performance: 1. Start a Ray cluster: ```bash @@ -205,6 +270,15 @@ offline_store: type: ray ray_address: localhost:10001 use_ray_cluster: true + # Cluster-optimized settings + max_parallelism_multiplier: 3 + target_partition_size_mb: 256 +``` + +3. For multiple worker nodes: +```bash +# On worker nodes +ray start --address='head-node-ip:10001' ``` ### Remote Storage @@ -225,3 +299,18 @@ store.materialize( feature_views=["driver_stats"] ) ``` + + +### Custom Optimization + +For specific workloads, you can fine-tune the configuration: + +```yaml +offline_store: + type: ray + # Fine-tuning for high-throughput scenarios + broadcast_join_threshold_mb: 200 # Larger broadcast threshold + max_parallelism_multiplier: 1 # Conservative parallelism + target_partition_size_mb: 512 # Larger partitions + window_size_for_joins: "2H" # Larger time windows +``` diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py index e69de29bb2d..d0eb96bfcb2 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/__init__.py @@ -0,0 +1,56 @@ +""" +Ray offline store for Feast. + +This module provides distributed offline feature store functionality using Ray with +advanced optimization features for scalable feature retrieval. + +Key Features: +- Intelligent join strategy selection (broadcast vs. distributed) +- Resource-aware partitioning and parallelism +- Windowed temporal joins for large datasets +- Configurable performance tuning parameters +- Automatic cluster resource management + +Classes: +- RayOfflineStore: Main offline store implementation +- RayOfflineStoreConfig: Configuration with optimization settings +- RayRetrievalJob: Enhanced retrieval job with caching +- RayResourceManager: Cluster resource management +- RayDataProcessor: Optimized data processing operations + +Usage: +Configure in your feature_store.yaml: +```yaml +offline_store: + type: ray + storage_path: /path/to/storage + broadcast_join_threshold_mb: 100 + enable_distributed_joins: true + max_parallelism_multiplier: 2 + target_partition_size_mb: 64 + window_size_for_joins: "1H" +``` + +Performance Optimizations: +- Broadcast joins for small datasets (<100MB by default) +- Distributed windowed joins for large datasets +- Optimal partitioning based on cluster resources +- Memory-aware buffer sizing +- Lazy evaluation with caching +""" + +from .ray import ( + RayDataProcessor, + RayOfflineStore, + RayOfflineStoreConfig, + RayResourceManager, + RayRetrievalJob, +) + +__all__ = [ + "RayOfflineStore", + "RayOfflineStoreConfig", + "RayRetrievalJob", + "RayResourceManager", + "RayDataProcessor", +] diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 11ad1341236..39a43c7c001 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -1,8 +1,9 @@ +import logging import os import uuid from datetime import datetime from pathlib import Path -from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union import fsspec import numpy as np @@ -26,28 +27,571 @@ RetrievalJob, RetrievalMetadata, ) -from feast.infra.offline_stores.offline_utils import get_expected_join_keys from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage, ValidationReference from feast.utils import _get_column_names, _utc_now, make_df_tzaware +logger = logging.getLogger(__name__) + + +class RayOfflineStoreConfig(FeastConfigBaseModel): + type: Literal[ + "feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore", "ray" + ] = "ray" + storage_path: Optional[str] = None + ray_address: Optional[str] = None + use_ray_cluster: Optional[bool] = False + + # Optimization settings + broadcast_join_threshold_mb: Optional[int] = 100 + enable_distributed_joins: Optional[bool] = True + max_parallelism_multiplier: Optional[int] = 2 + target_partition_size_mb: Optional[int] = 64 + window_size_for_joins: Optional[str] = "1H" + + +class RayResourceManager: + """Manages Ray cluster resources for optimal performance.""" + + def __init__(self, config: Optional[RayOfflineStoreConfig] = None): + self.config = config or RayOfflineStoreConfig() + self.cluster_resources = ray.cluster_resources() + self.available_memory = self.cluster_resources.get( + "memory", 8 * 1024**3 + ) # 8GB default + self.available_cpus = int(self.cluster_resources.get("CPU", 4)) + self.num_nodes = len(ray.nodes()) if ray.is_initialized() else 1 + + def configure_ray_context(self): + """Configure Ray DatasetContext for optimal performance.""" + ctx = DatasetContext.get_current() + + # Set buffer sizes based on available memory + if self.available_memory > 32 * 1024**3: # 32GB + ctx.target_shuffle_buffer_size = 2 * 1024**3 # 2GB + ctx.target_max_block_size = 512 * 1024**2 # 512MB + else: + ctx.target_shuffle_buffer_size = 512 * 1024**2 # 512MB + ctx.target_max_block_size = 128 * 1024**2 # 128MB + + # Configure parallelism + ctx.min_parallelism = self.available_cpus + ctx.max_parallelism = ( + self.available_cpus * self.config.max_parallelism_multiplier + ) + + # Optimize for feature store workloads + ctx.shuffle_strategy = "sort" + ctx.enable_tensor_extension_casting = False + + logger.info( + f"Configured Ray context: {self.available_cpus} CPUs, " + f"{self.available_memory // 1024**3}GB memory, {self.num_nodes} nodes" + ) + + def estimate_optimal_partitions(self, dataset_size_bytes: int) -> int: + """Estimate optimal number of partitions for a dataset.""" + # Use configured target partition size + target_partition_size = (self.config.target_partition_size_mb or 64) * 1024**2 + size_based_partitions = max(1, dataset_size_bytes // target_partition_size) + + # Don't exceed configured max parallelism + max_partitions = self.available_cpus * ( + self.config.max_parallelism_multiplier or 2 + ) + + return min(size_based_partitions, max_partitions) + + def should_use_broadcast_join( + self, dataset_size_bytes: int, threshold_mb: Optional[int] = None + ) -> bool: + """Determine if dataset is small enough for broadcast join.""" + threshold = ( + threshold_mb + if threshold_mb is not None + else (self.config.broadcast_join_threshold_mb or 100) + ) + return dataset_size_bytes <= threshold * 1024**2 + + def estimate_processing_requirements( + self, dataset_size_bytes: int, operation_type: str + ) -> Dict[str, Any]: + """Estimate resource requirements for different operations.""" + + # Memory requirements (with safety margin) + memory_multiplier = { + "read": 1.2, # 20% overhead for reading + "join": 3.0, # 3x for join operations + "aggregate": 2.0, # 2x for aggregations + "shuffle": 2.5, # 2.5x for shuffling + } + + required_memory = dataset_size_bytes * memory_multiplier.get( + operation_type, 2.0 + ) + + return { + "required_memory": required_memory, + "optimal_partitions": self.estimate_optimal_partitions(dataset_size_bytes), + "can_fit_in_memory": required_memory <= self.available_memory * 0.8, + "should_broadcast": self.should_use_broadcast_join(dataset_size_bytes), + } + + +class RayDataProcessor: + """Optimized data processing with Ray for feature store operations.""" + + def __init__(self, resource_manager: RayResourceManager): + self.resource_manager = resource_manager + + def optimize_dataset_for_join(self, ds: Dataset, join_keys: List[str]) -> Dataset: + """Optimize dataset partitioning for join operations.""" + + # Estimate optimal partitions + dataset_size = ds.size_bytes() + optimal_partitions = self.resource_manager.estimate_optimal_partitions( + dataset_size + ) + + if not join_keys: + # For datasets without join keys, use simple repartitioning + return ds.repartition(num_blocks=optimal_partitions) + + # For datasets with join keys, use shuffle for better distribution + return ds.random_shuffle(num_blocks=optimal_partitions) + + def broadcast_join_features( + self, + entity_ds: Dataset, + feature_df: pd.DataFrame, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, + ) -> Dataset: + """Perform broadcast join for small feature datasets.""" + + # Put feature data in Ray object store for efficient broadcasting + feature_ref = ray.put(feature_df) + + def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: + """Join a batch with broadcast feature data.""" + features = ray.get(feature_ref) + + logger.debug(f"Broadcast join - Features DataFrame shape: {features.shape}") + logger.debug( + f"Broadcast join - Features DataFrame columns: {list(features.columns)}" + ) + logger.debug(f"Broadcast join - Requested features: {requested_feats}") + logger.debug(f"Broadcast join - Join keys: {join_keys}") + logger.debug(f"Broadcast join - Timestamp field: {timestamp_field}") + logger.debug(f"Broadcast join - Batch DataFrame shape: {batch.shape}") + logger.debug( + f"Broadcast join - Batch DataFrame columns: {list(batch.columns)}" + ) + if feature_view_name: + logger.info( + f"Processing feature view {feature_view_name} with join keys {join_keys}" + ) + + # Select only required feature columns plus join keys and timestamp + # Use original join keys for filtering if provided (for entity mapping) + filter_join_keys = original_join_keys if original_join_keys else join_keys + feature_cols = [timestamp_field] + filter_join_keys + requested_feats + features_filtered = features[feature_cols].copy() + + logger.debug( + f"Broadcast join - Features filtered shape: {features_filtered.shape}" + ) + logger.debug( + f"Broadcast join - Features filtered columns: {list(features_filtered.columns)}" + ) + + # Ensure timestamp columns have compatible dtypes and precision + if timestamp_field in batch.columns: + batch[timestamp_field] = ( + pd.to_datetime(batch[timestamp_field], utc=True, errors="coerce") + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + if timestamp_field in features_filtered.columns: + features_filtered[timestamp_field] = ( + pd.to_datetime( + features_filtered[timestamp_field], utc=True, errors="coerce" + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + if not join_keys: + # Temporal join without entity keys + batch_sorted = batch.sort_values(timestamp_field).reset_index(drop=True) + features_sorted = features_filtered.sort_values( + timestamp_field + ).reset_index(drop=True) + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + ) + else: + # Temporal join with entity keys + # Clean data first by removing NaN values + # Use original join keys for filtering feature dataset, mapped join keys for entity dataset + feature_join_keys = ( + original_join_keys if original_join_keys else join_keys + ) + + for key in join_keys: + if key not in batch.columns: + batch[key] = None + for key in feature_join_keys: + if key not in features_filtered.columns: + features_filtered[key] = None + + # Drop rows with NaN values in join keys or timestamp + batch_clean = batch.dropna(subset=join_keys + [timestamp_field]).copy() + features_clean = features_filtered.dropna( + subset=feature_join_keys + [timestamp_field] + ).copy() + + # If no valid data remains, return empty result + if batch_clean.empty or features_clean.empty: + return batch.head(0) # Return empty dataframe with same columns + + # Important: For merge_asof with 'by' parameter, sort by 'by' columns first, then by 'on' column + # Both DataFrames must be sorted identically + batch_sort_columns = join_keys + [timestamp_field] + features_sort_columns = feature_join_keys + [timestamp_field] + + try: + # For entity mapping, we need to manually join since merge_asof doesn't support different column names + if original_join_keys and original_join_keys != join_keys: + # Manual join for entity mapping + logger.info("Using manual join for entity mapping") + raise ValueError("Entity mapping requires manual join") + + # For multi-entity joins, use manual join to ensure correctness + if len(join_keys) > 1: + logger.info( + f"Using manual join for multi-entity join with keys: {join_keys}" + ) + raise ValueError("Multi-entity join requires manual join") + + # Sort both DataFrames consistently + batch_sorted = batch_clean.sort_values( + batch_sort_columns, ascending=True + ).reset_index(drop=True) + + features_sorted = features_clean.sort_values( + features_sort_columns, ascending=True + ).reset_index(drop=True) + + # Verify sorting (merge_asof requirement) + for key in join_keys: + if not batch_sorted[key].is_monotonic_increasing: + # If not monotonic, we need to handle this differently + logger.warning( + f"Join key {key} is not monotonic, using manual join" + ) + raise ValueError(f"Join key {key} is not monotonic") + if not features_sorted[key].is_monotonic_increasing: + logger.warning( + f"Feature join key {key} is not monotonic, using manual join" + ) + raise ValueError(f"Feature join key {key} is not monotonic") + + # Perform merge_asof + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + by=join_keys, + direction="backward", + ) + logger.debug( + f"merge_asof succeeded for batch of size {len(batch_sorted)}" + ) + + except (ValueError, KeyError) as e: + # If merge_asof fails, implement manual point-in-time join + logger.warning( + f"merge_asof failed, implementing manual point-in-time join: {e}" + ) + + # Group by join keys and apply point-in-time logic manually + result_chunks = [] + + logger.debug( + f"Manual join - batch_clean shape: {batch_clean.shape}" + ) + logger.debug( + f"Manual join - features_clean shape: {features_clean.shape}" + ) + logger.debug(f"Manual join - join_keys: {join_keys}") + logger.debug( + f"Manual join - feature_join_keys: {feature_join_keys}" + ) + + for join_key_vals, entity_group in batch_clean.groupby(join_keys): + # Create dictionary for filtering features by join keys + if len(join_keys) == 1: + entity_key_filter = {join_keys[0]: join_key_vals} + else: + entity_key_filter = dict(zip(join_keys, join_key_vals)) + + # For entity mapping, map the entity keys to feature keys + feature_key_filter = {} + for i, entity_key in enumerate(join_keys): + feature_key = ( + feature_join_keys[i] + if i < len(feature_join_keys) + else entity_key + ) + feature_key_filter[feature_key] = entity_key_filter[ + entity_key + ] + + # Filter features for this join key group + feature_group = features_clean + for key, val in feature_key_filter.items(): + # Only filter if the key exists in the feature dataset + if key in feature_group.columns: + feature_group = feature_group[feature_group[key] == val] + logger.debug( + f"Filtered by {key}={val}: {len(feature_group)} rows remaining" + ) + else: + logger.warning( + f"Join key {key} not found in feature dataset columns: {list(feature_group.columns)}" + ) + # If the key is missing, we can't match, so return empty + feature_group = feature_group.iloc[0:0] + break + + if len(feature_group) == 0: + # No features found, add NaN columns + entity_result = entity_group.copy() + for feat in requested_feats: + if feat not in entity_result.columns: + entity_result[feat] = np.nan + result_chunks.append(entity_result) + else: + # Apply point-in-time logic: for each entity timestamp, find the latest feature + entity_result = entity_group.copy() + for feat in requested_feats: + if feat not in entity_result.columns: + entity_result[feat] = np.nan + + # For each row in entity group, find the latest feature value + for idx, entity_row in entity_group.iterrows(): + entity_ts = entity_row[timestamp_field] + # Find features with timestamp <= entity timestamp + valid_features = feature_group[ + feature_group[timestamp_field] <= entity_ts + ] + if len(valid_features) > 0: + # Sort by timestamp to ensure we get the latest feature + valid_features = valid_features.sort_values( + timestamp_field + ) + latest_feature = valid_features.iloc[-1] + # Update the result with feature values + for feat in requested_feats: + if feat in latest_feature: + entity_result.loc[idx, feat] = ( + latest_feature[feat] + ) + + result_chunks.append(entity_result) + + if result_chunks: + result = pd.concat(result_chunks, ignore_index=True) + else: + result = batch_clean.copy() + for feat in requested_feats: + if feat not in result.columns: + result[feat] = np.nan + + # Debug logging for join result + logger.debug(f"Join result shape: {result.shape}") + logger.debug(f"Join result columns: {list(result.columns)}") + + # Handle feature renaming if full_feature_names is True + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + + return result + + return entity_ds.map_batches(join_batch_with_features, batch_format="pandas") + + def windowed_temporal_join( + self, + entity_ds: Dataset, + feature_ds: Dataset, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + window_size: Optional[str] = None, + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, + ) -> Dataset: + """Perform windowed temporal join for large datasets.""" + + # Use configured window size if not provided + window_size = window_size or ( + self.resource_manager.config.window_size_for_joins or "1H" + ) + + # Step 1: Optimize both datasets for joining + entity_optimized = self.optimize_dataset_for_join(entity_ds, join_keys) + feature_optimized = self.optimize_dataset_for_join(feature_ds, join_keys) + + # Step 2: Add time windows and data source markers + entity_windowed = self._add_time_windows_and_source_marker( + entity_optimized, timestamp_field, "entity", window_size + ) + feature_windowed = self._add_time_windows_and_source_marker( + feature_optimized, timestamp_field, "feature", window_size + ) + + # Step 3: Union datasets for co-processing + combined_ds = entity_windowed.union(feature_windowed) + + # Step 4: Group by time window and join keys, then apply point-in-time logic + result_ds = combined_ds.map_batches( + self._apply_windowed_point_in_time_logic, + batch_format="pandas", + fn_kwargs={ + "timestamp_field": timestamp_field, + "join_keys": join_keys, + "requested_feats": requested_feats, + "full_feature_names": full_feature_names, + "feature_view_name": feature_view_name, + "original_join_keys": original_join_keys, + }, + ) + + return result_ds + + def _add_time_windows_and_source_marker( + self, ds: Dataset, timestamp_field: str, source_marker: str, window_size: str + ) -> Dataset: + """Add time windows and source markers to dataset.""" + + def add_window_and_source(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + batch["time_window"] = ( + pd.to_datetime(batch[timestamp_field]) + .dt.floor(window_size) + .astype("datetime64[ns, UTC]") + ) + batch["_data_source"] = source_marker + return batch + + return ds.map_batches(add_window_and_source, batch_format="pandas") + + def _apply_windowed_point_in_time_logic( + self, + batch: pd.DataFrame, + timestamp_field: str, + join_keys: List[str], + requested_feats: List[str], + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, + ) -> pd.DataFrame: + """Apply point-in-time correctness within time windows.""" + + if len(batch) == 0: + return pd.DataFrame() + + # Group by window and join keys to apply merge_asof + result_chunks = [] + group_keys = ["time_window"] + join_keys + + for group_values, group_data in batch.groupby(group_keys): + # Separate entity and feature data + entity_data = group_data[group_data["_data_source"] == "entity"].copy() + feature_data = group_data[group_data["_data_source"] == "feature"].copy() + + if len(entity_data) > 0 and len(feature_data) > 0: + # Drop helper columns for merge_asof + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + feature_clean = feature_data.drop( + columns=["time_window", "_data_source"] + ) + + # Apply merge_asof within the group + if join_keys: + merged = pd.merge_asof( + entity_clean.sort_values(join_keys + [timestamp_field]), + feature_clean.sort_values(join_keys + [timestamp_field]), + on=timestamp_field, + by=join_keys, + direction="backward", + ) + else: + merged = pd.merge_asof( + entity_clean.sort_values(timestamp_field), + feature_clean.sort_values(timestamp_field), + on=timestamp_field, + direction="backward", + ) + + result_chunks.append(merged) + elif len(entity_data) > 0: + # No features found, return entity data with NaN features + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + for feat in requested_feats: + if feat not in entity_clean.columns: + entity_clean[feat] = np.nan + result_chunks.append(entity_clean) + + if result_chunks: + result = pd.concat(result_chunks, ignore_index=True) + + # Handle feature renaming if full_feature_names is True + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + + return result + else: + return pd.DataFrame() + class RayRetrievalJob(RetrievalJob): def __init__( self, - dataset_or_callable: Union[Dataset, Callable[[], Dataset]], + dataset_or_callable: Union[ + Dataset, pd.DataFrame, Callable[[], Union[Dataset, pd.DataFrame]] + ], staging_location: Optional[str] = None, ): self._dataset_or_callable = dataset_or_callable self._staging_location = staging_location + self._cached_df: Optional[pd.DataFrame] = None self._cached_dataset: Optional[Dataset] = None self._metadata: Optional[RetrievalMetadata] = None self._full_feature_names: bool = False self._on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None - def _resolve(self) -> Any: + def _resolve(self) -> Union[Dataset, pd.DataFrame]: if callable(self._dataset_or_callable): result = self._dataset_or_callable() else: @@ -59,21 +603,60 @@ def to_df( validation_reference: Optional[ValidationReference] = None, timeout: Optional[int] = None, ) -> pd.DataFrame: + # Use cached DataFrame if available for repeated access + if self._cached_df is not None and not self.on_demand_feature_views: + return self._cached_df + + # If we have on-demand feature views, use the parent's implementation + # which calls to_arrow and applies the transformations + if self.on_demand_feature_views: + logger.info( + f"Using parent implementation for {len(self.on_demand_feature_views)} ODFVs" + ) + return super().to_df( + validation_reference=validation_reference, timeout=timeout + ) + result = self._resolve() if isinstance(result, pd.DataFrame): + self._cached_df = result return result - return result.to_pandas() + + # Convert Ray Dataset to DataFrame with progress logging + logger.info("Converting Ray dataset to DataFrame...") + self._cached_df = result.to_pandas() + logger.info(f"Converted dataset to DataFrame: {self._cached_df.shape}") + return self._cached_df def to_arrow( self, validation_reference: Optional[ValidationReference] = None, timeout: Optional[int] = None, ) -> pa.Table: + # If we have ODFVs, use the parent's implementation + if self.on_demand_feature_views: + logger.debug( + f"Using parent implementation for {len(self.on_demand_feature_views)} ODFVs" + ) + return super().to_arrow( + validation_reference=validation_reference, timeout=timeout + ) + + # For non-ODFV cases, use direct conversion result = self._resolve() if isinstance(result, pd.DataFrame): return pa.Table.from_pandas(result) - # For Ray Dataset, convert to pandas first then to arrow - return pa.Table.from_pandas(result.to_pandas()) + + # For Ray Dataset, use direct Arrow conversion if available + try: + if hasattr(result, "to_arrow"): + return result.to_arrow() + else: + # Fallback to pandas conversion + return pa.Table.from_pandas(result.to_pandas()) + except Exception: + # Fallback to pandas conversion + return pa.Table.from_pandas(result.to_pandas()) def to_remote_storage(self) -> list[str]: if not self._staging_location: @@ -109,9 +692,22 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: result = self._resolve() if isinstance(result, pd.DataFrame): + logger.debug(f"_to_arrow_internal: DataFrame shape: {result.shape}") + logger.debug( + f"_to_arrow_internal: DataFrame columns: {list(result.columns)}" + ) return pa.Table.from_pandas(result) + # For Ray Dataset, convert to pandas first then to arrow - return pa.Table.from_pandas(result.to_pandas()) + logger.debug( + "_to_arrow_internal: Converting Ray Dataset to pandas then to arrow" + ) + df = result.to_pandas() + logger.debug(f"_to_arrow_internal: Converted dataset shape: {df.shape}") + logger.debug( + f"_to_arrow_internal: Converted dataset columns: {list(df.columns)}" + ) + return pa.Table.from_pandas(df) def persist( self, @@ -139,19 +735,12 @@ def persist( raise RuntimeError(f"Failed to persist dataset to {destination_path}: {e}") -class RayOfflineStoreConfig(FeastConfigBaseModel): - type: Literal[ - "feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore", "ray" - ] = "ray" - storage_path: Optional[str] = None - ray_address: Optional[str] = None - use_ray_cluster: Optional[bool] = False - - class RayOfflineStore(OfflineStore): def __init__(self): self._staging_location: Optional[str] = None self._ray_initialized: bool = False + self._resource_manager: Optional[RayResourceManager] = None + self._data_processor: Optional[RayDataProcessor] = None @staticmethod def _ensure_ray_initialized(config: Optional[RepoConfig] = None): @@ -179,7 +768,7 @@ def _ensure_ray_initialized(config: Optional[RepoConfig] = None): ray.init(ignore_reinit_error=True) ctx = DatasetContext.get_current() - ctx.shuffle_strategy = "sort" + ctx.shuffle_strategy = "sort" # type: ignore ctx.enable_tensor_extension_casting = False def _init_ray(self, config: RepoConfig): @@ -188,6 +777,14 @@ def _init_ray(self, config: RepoConfig): self._ensure_ray_initialized(config) + # Initialize optimization components + if self._resource_manager is None: + self._resource_manager = RayResourceManager(ray_config) + self._resource_manager.configure_ray_context() + + if self._data_processor is None: + self._data_processor = RayDataProcessor(self._resource_manager) + def _get_source_path(self, source: DataSource, config: RepoConfig) -> str: if not isinstance(source, FileSource): raise ValueError("RayOfflineStore currently only supports FileSource") @@ -250,22 +847,31 @@ def get_historical_features( store = RayOfflineStore() store._init_ray(config) - # Load entity_df - original_entity_df = ( - pd.read_csv(entity_df) if isinstance(entity_df, str) else entity_df.copy() - ) - result_df = make_df_tzaware(original_entity_df.copy()) - if "event_timestamp" in result_df.columns: - result_df["event_timestamp"] = pd.to_datetime( - result_df["event_timestamp"], utc=True, errors="coerce" - ).dt.floor("s") + # Load entity_df as Ray dataset for distributed processing + if isinstance(entity_df, str): + entity_ds = ray.data.read_csv(entity_df) + original_entity_df = pd.read_csv(entity_df) + else: + entity_ds = ray.data.from_pandas(entity_df) + original_entity_df = entity_df.copy() + + # Make entity dataframe timezone aware + original_entity_df = make_df_tzaware(original_entity_df) + if "event_timestamp" in original_entity_df.columns: + original_entity_df["event_timestamp"] = ( + pd.to_datetime( + original_entity_df["event_timestamp"], utc=True, errors="coerce" + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) # Parse feature_refs and get ODFVs on_demand_feature_views = OnDemandFeatureView.get_requested_odfvs( feature_refs, project, registry ) - # --- Request Data Validation for ODFVs --- + # Validate request data for ODFVs for odfv in on_demand_feature_views: odfv_request_data_schema = odfv.get_request_data_schema() for feature_name in odfv_request_data_schema.keys(): @@ -275,199 +881,261 @@ def get_historical_features( feature_view_name=odfv.name, ) - # Collect all join keys from feature views - all_join_keys = get_expected_join_keys(project, feature_views, registry) - if "event_timestamp" in result_df.columns: - all_join_keys.add("event_timestamp") - - # Keep only relevant entity columns and timestamp - result_df = result_df[ - [col for col in result_df.columns if col in all_join_keys] + # Filter out on-demand feature views from regular feature views + # ODFVs don't have data sources and are computed from base features + odfv_names = {odfv.name for odfv in on_demand_feature_views} + regular_feature_views = [ + fv for fv in feature_views if fv.name not in odfv_names ] - requested_feature_columns = [] - added_dummy_columns = set() + logger.info( + f"Processing {len(regular_feature_views)} regular feature views and {len(on_demand_feature_views)} on-demand feature views with {len(feature_refs)} feature references" + ) + + # Apply field mappings to entity dataset if needed + global_field_mappings = {} + for fv in regular_feature_views: + mapping = getattr(fv.batch_source, "field_mapping", None) + if mapping: + for k, v in mapping.items(): + global_field_mappings[v] = k + + if global_field_mappings: + cols_to_rename = { + v: k + for k, v in global_field_mappings.items() + if v in original_entity_df.columns + } + if cols_to_rename: + entity_ds = entity_ds.map_batches( + lambda batch: batch.rename(columns=cols_to_rename), + batch_format="pandas", + ) + + # Start with entity dataset + result_ds = entity_ds - # Join each feature view - for fv in feature_views: - # Only process feature views that are referenced + # Process each regular feature view with intelligent join strategy + for fv in regular_feature_views: fv_feature_refs = [ - ref for ref in feature_refs if ref.startswith(fv.name + ":") + ref + for ref in feature_refs + if ref.startswith(fv.projection.name_to_use() + ":") ] if not fv_feature_refs: continue - # Get join keys, feature names, timestamp, created timestamp + logger.info(f"Processing feature view: {fv.name}") + + # Get join configuration entities = fv.entities or [] entity_objs = [registry.get_entity(e, project) for e in entities] - join_keys, feature_names, timestamp_field, created_col = _get_column_names( + original_join_keys, _, timestamp_field, created_col = _get_column_names( fv, entity_objs ) - if not join_keys: - join_keys = [DUMMY_ENTITY_ID] - # Only add features that are actually requested in feature_refs + # Apply join key mapping from projection if present + if fv.projection.join_key_map: + join_keys = [ + fv.projection.join_key_map.get(key, key) + for key in original_join_keys + ] + else: + join_keys = original_join_keys + + # Extract requested features requested_feats = [ref.split(":", 1)[1] for ref in fv_feature_refs] - # --- Error for Missing Features --- + # Validate requested features exist available_feature_names = [f.name for f in fv.features] missing_feats = [ f for f in requested_feats if f not in available_feature_names ] if missing_feats: raise KeyError( - f"Requested features {missing_feats} not found in feature view '{fv.name}' (available: {available_feature_names})" + f"Requested features {missing_feats} not found in feature view '{fv.name}' " + f"(available: {available_feature_names})" ) - for feat in requested_feats: - col_name = f"{fv.name}__{feat}" if full_feature_names else feat - requested_feature_columns.append(col_name) + logger.info( + f"Feature view '{fv.name}': requesting {requested_feats}, available: {available_feature_names}" + ) - # Read feature data + # Load feature data as Ray dataset source_path = store._get_source_path(fv.batch_source, config) - if not source_path: - raise ValueError(f"Missing batch source for FV {fv.name}") - feature_ds = ray.data.read_parquet(str(source_path)) - feature_df = feature_ds.to_pandas() - feature_df = make_df_tzaware(feature_df) - if timestamp_field in feature_df.columns: - feature_df[timestamp_field] = pd.to_datetime( - feature_df[timestamp_field], utc=True, errors="coerce" - ).dt.floor("s") - - # Ensure join keys exist in both entity and feature dataframe - for k in join_keys: - if k not in result_df.columns: - result_df[k] = DUMMY_ENTITY_VAL - added_dummy_columns.add(k) - if k not in feature_df.columns: - feature_df[k] = DUMMY_ENTITY_VAL - + feature_ds = ray.data.read_parquet(source_path) + feature_size = feature_ds.size_bytes() + + # Apply field mapping to feature dataset if needed + field_mapping = getattr(fv.batch_source, "field_mapping", None) + if field_mapping: + feature_ds = feature_ds.map_batches( + lambda batch: batch.rename(columns=field_mapping), + batch_format="pandas", + ) + # Update join keys and timestamp field to mapped names + join_keys = [field_mapping.get(k, k) for k in join_keys] + timestamp_field = field_mapping.get(timestamp_field, timestamp_field) + if created_col: + created_col = field_mapping.get(created_col, created_col) + + # Apply projection join key mapping to entity dataset if needed + if fv.projection.join_key_map: + # The feature dataset keeps its original columns (e.g., location_id) + # The entity dataset gets the mapped columns (e.g., origin_id, destination_id) + # We need to ensure the entity dataset has the properly mapped columns + pass # The entity dataset already has the mapped columns in this case + + # Ensure timestamp compatibility in entity dataset if ( - timestamp_field not in result_df.columns - and "event_timestamp" in result_df.columns + timestamp_field != "event_timestamp" + and timestamp_field not in original_entity_df.columns + and "event_timestamp" in original_entity_df.columns ): - result_df[timestamp_field] = result_df["event_timestamp"] - - # Align join key dtypes before merge - for k in join_keys: - if k in result_df.columns and k in feature_df.columns: - feature_df[k] = feature_df[k].astype(result_df[k].dtype) - - # Deduplicate feature values (avoid list columns in keys) - dedup_keys = join_keys + [timestamp_field] - if created_col and created_col in feature_df.columns: - feature_df = feature_df.sort_values(by=dedup_keys + [created_col]) - feature_df = feature_df.groupby(dedup_keys, as_index=False).last() - else: - feature_df = feature_df.sort_values(by=dedup_keys) - feature_df = feature_df.drop_duplicates(subset=dedup_keys, keep="last") - - # Select only requested features that exist in feature_df - existing_feats = [f for f in requested_feats if f in feature_df.columns] - cols_to_keep = join_keys + [timestamp_field] + existing_feats - feature_df = feature_df[cols_to_keep] - - # Join into result_df - result_df = result_df.merge( - feature_df, - how="inner", - on=join_keys - + ([timestamp_field] if timestamp_field in result_df.columns else []), + + def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + batch[timestamp_field] = ( + pd.to_datetime( + batch["event_timestamp"], utc=True, errors="coerce" + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + return batch + + result_ds = result_ds.map_batches( + add_timestamp_field, batch_format="pandas" + ) + + # Determine join strategy based on dataset sizes and cluster resources + if store._resource_manager is None: + raise ValueError("Resource manager not initialized") + requirements = store._resource_manager.estimate_processing_requirements( + feature_size, "join" ) - # Handle full feature names - if full_feature_names: - result_df = result_df.rename( - columns={ - f: f"{fv.name}__{f}" - for f in existing_feats - if f in result_df.columns - } + if requirements["should_broadcast"]: + # Use broadcast join for small feature datasets + logger.info( + f"Using broadcast join for {fv.name} (size: {feature_size // 1024**2}MB)" ) + feature_df = feature_ds.to_pandas() + feature_df = make_df_tzaware(feature_df) - # Re-attach original entity columns - for col in original_entity_df.columns: - if col not in result_df.columns: - result_df[col] = original_entity_df[col] - - # Ensure event_timestamp is present - if ( - "event_timestamp" not in result_df.columns - and "event_timestamp" in original_entity_df.columns - ): - result_df["event_timestamp"] = pd.to_datetime( - original_entity_df["event_timestamp"], utc=True, errors="coerce" - ).dt.floor("s") - - if ( - "event_timestamp" not in result_df.columns - and timestamp_field in result_df.columns - ): - result_df["event_timestamp"] = result_df[timestamp_field] - - # Drop dummy entity columns - for dummy_col in added_dummy_columns: - if dummy_col in result_df.columns: - result_df = result_df.drop(columns=[dummy_col]) - - # Reorder columns: entity + timestamp + features (in requested order) - entity_columns = [ - c for c in original_entity_df.columns if c != "event_timestamp" - ] - # Build the list of output feature columns in the correct order - output_feature_columns = [] - for ref in feature_refs: - fv_name, feat = ref.split(":", 1) - col_name = f"{fv_name}__{feat}" if full_feature_names else feat - output_feature_columns.append(col_name) - - # Ensure all requested features are present, fill with NaN if missing - for col in output_feature_columns: - if col not in result_df.columns: - result_df[col] = np.nan - - final_columns = entity_columns + ["event_timestamp"] + output_feature_columns - result_df = result_df.reindex(columns=final_columns) - - # Convert list/numpy.ndarray columns to tuples for deduplication - def make_hashable_for_dedup(df, columns): - for col in columns: - if col in df.columns: - if df[col].apply(lambda x: isinstance(x, (np.ndarray, list))).any(): - df[col] = df[col].apply( - lambda x: tuple(x) - if isinstance(x, (np.ndarray, list)) - else x + if timestamp_field in feature_df.columns: + feature_df[timestamp_field] = ( + pd.to_datetime( + feature_df[timestamp_field], utc=True, errors="coerce" ) - return df + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) - list_columns = [ - col - for col in final_columns - if col in result_df.columns - and result_df[col].apply(lambda x: isinstance(x, (np.ndarray, list))).any() - ] - result_df = make_hashable_for_dedup(result_df, list_columns) + if store._data_processor is None: + raise ValueError("Data processor not initialized") + result_ds = store._data_processor.broadcast_join_features( + result_ds, + feature_df, + join_keys, + timestamp_field, + requested_feats, + full_feature_names, + fv.projection.name_to_use(), + original_join_keys, + ) + else: + # Use distributed windowed join for large feature datasets + logger.info( + f"Using distributed join for {fv.name} (size: {feature_size // 1024**2}MB)" + ) + + # Ensure timestamp format in feature dataset + def normalize_timestamps(batch: pd.DataFrame) -> pd.DataFrame: + batch = make_df_tzaware(batch) + if timestamp_field in batch.columns: + batch[timestamp_field] = ( + pd.to_datetime( + batch[timestamp_field], utc=True, errors="coerce" + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + return batch - # Deduplicate - result_df = result_df.drop_duplicates().reset_index(drop=True) + feature_ds = feature_ds.map_batches( + normalize_timestamps, batch_format="pandas" + ) - # Convert tuple columns back to lists - for col in list_columns: - if col in result_df.columns: - result_df[col] = result_df[col].apply( - lambda x: list(x) if isinstance(x, tuple) else x + if store._data_processor is None: + raise ValueError("Data processor not initialized") + result_ds = store._data_processor.windowed_temporal_join( + result_ds, + feature_ds, + join_keys, + timestamp_field, + requested_feats, + window_size=config.offline_store.window_size_for_joins, + full_feature_names=full_feature_names, + feature_view_name=fv.projection.name_to_use(), + original_join_keys=original_join_keys, ) - # Return retrieval job + # Final processing: clean up and ensure proper column structure + def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: + logger.debug(f"Finalizing result - input columns: {list(batch.columns)}") + logger.debug(f"Finalizing result - batch shape: {batch.shape}") + + batch = batch.copy() + + # Preserve existing feature columns (including renamed ones) + existing_columns = set(batch.columns) + + # Re-attach any missing original entity columns that aren't already present + for col in original_entity_df.columns: + if col not in existing_columns: + # For missing columns, use values from original entity df + if len(batch) <= len(original_entity_df): + batch[col] = original_entity_df[col].iloc[: len(batch)].values + else: + # Repeat values if batch is larger + repeated_values = np.tile( + original_entity_df[col].values, + (len(batch) // len(original_entity_df) + 1), + ) + batch[col] = repeated_values[: len(batch)] + + # Ensure event_timestamp is present + if "event_timestamp" not in batch.columns: + if "event_timestamp" in original_entity_df.columns: + batch["event_timestamp"] = ( + pd.to_datetime( + original_entity_df["event_timestamp"].iloc[: len(batch)], + utc=True, + errors="coerce", + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + elif timestamp_field in batch.columns: + batch["event_timestamp"] = batch[timestamp_field] + + logger.debug(f"Final columns: {list(batch.columns)}") + return batch + + result_ds = result_ds.map_batches(finalize_result, batch_format="pandas") + + # Storage path validation storage_path = config.offline_store.storage_path if not storage_path: raise ValueError("Storage path must be set in config") - job = RayRetrievalJob(result_df, staging_location=storage_path) + # Create retrieval job following standard pattern + job = RayRetrievalJob(result_ds, staging_location=storage_path) job._full_feature_names = full_feature_names job._on_demand_feature_views = on_demand_feature_views + + logger.info("Historical features processing completed successfully") return job def validate_data_source( @@ -509,9 +1177,128 @@ def pull_latest_from_table_or_query( def _load(): try: - return RayOfflineStore._create_filtered_dataset( + # Load and filter the dataset + ds = RayOfflineStore._create_filtered_dataset( source_path, timestamp_field, start_date, end_date ) + + # Convert to pandas for deduplication and column selection + df = ds.to_pandas() + df = make_df_tzaware(df) + + # Apply field mapping if needed + field_mapping = getattr(data_source, "field_mapping", None) + if field_mapping: + df = df.rename(columns=field_mapping) + + # Use the actual timestamp field name (this is already the correct mapped name) + timestamp_field_mapped = timestamp_field + created_timestamp_column_mapped = created_timestamp_column + + # Handle empty DataFrame case + if df.empty: + logger.info( + "DataFrame is empty after filtering, creating empty DataFrame with required columns" + ) + # Create an empty DataFrame with the required columns + empty_columns = ( + join_key_columns + + feature_name_columns + + [timestamp_field_mapped] + ) + if created_timestamp_column_mapped: + empty_columns.append(created_timestamp_column_mapped) + if not join_key_columns: + empty_columns.append(DUMMY_ENTITY_ID) + + # Add event_timestamp column for pandas backend compatibility + if "event_timestamp" not in empty_columns: + empty_columns.append("event_timestamp") + + # Create empty DataFrame with proper column types + empty_df = pd.DataFrame(columns=empty_columns) + return empty_df + + # Ensure timestamp is properly formatted + if timestamp_field_mapped in df.columns: + df[timestamp_field_mapped] = ( + pd.to_datetime( + df[timestamp_field_mapped], utc=True, errors="coerce" + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + if ( + created_timestamp_column_mapped + and created_timestamp_column_mapped in df.columns + ): + df[created_timestamp_column_mapped] = ( + pd.to_datetime( + df[created_timestamp_column_mapped], + utc=True, + errors="coerce", + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + # Prepare columns to select + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + + all_required_columns = ( + join_key_columns + feature_name_columns + timestamp_columns + ) + + # Select only the required columns that exist + available_columns = [ + col for col in all_required_columns if col in df.columns + ] + df = df[available_columns] + + # Handle deduplication (keep latest records) + if join_key_columns: + # Sort by timestamp columns (latest first) and deduplicate by join keys + # Filter out timestamp columns that don't exist in the dataframe + existing_timestamp_columns = [ + col for col in timestamp_columns if col in df.columns + ] + sort_columns = join_key_columns + existing_timestamp_columns + if sort_columns: + df = df.sort_values( + sort_columns, + ascending=[True] * len(join_key_columns) + + [False] * len(existing_timestamp_columns), + ) + df = df.drop_duplicates(subset=join_key_columns, keep="first") + else: + # No join keys - add dummy entity and sort by timestamp + df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + # Filter out timestamp columns that don't exist in the dataframe + existing_timestamp_columns = [ + col for col in timestamp_columns if col in df.columns + ] + if existing_timestamp_columns: + df = df.sort_values(existing_timestamp_columns, ascending=False) + + # Reset index + df = df.reset_index(drop=True) + + # Ensure 'event_timestamp' column exists for pandas backend compatibility + if ( + "event_timestamp" not in df.columns + and timestamp_field_mapped != "event_timestamp" + ): + if timestamp_field_mapped in df.columns: + df["event_timestamp"] = df[timestamp_field_mapped] + logger.debug( + f"Added 'event_timestamp' column from '{timestamp_field_mapped}' for pandas backend compatibility" + ) + + return df + except Exception as e: raise RuntimeError(f"Failed to load data from {source_path}: {e}") @@ -541,9 +1328,124 @@ def pull_all_from_table_or_query( def _load(): try: - return RayOfflineStore._create_filtered_dataset( + # Load and filter the dataset + ds = RayOfflineStore._create_filtered_dataset( source_path, timestamp_field, start_date, end_date ) + + # Convert to pandas for column selection + df = ds.to_pandas() + df = make_df_tzaware(df) + + # Apply field mapping if needed + field_mapping = getattr(data_source, "field_mapping", None) + if field_mapping: + df = df.rename(columns=field_mapping) + + # Use the actual timestamp field name (this is already the correct mapped name) + timestamp_field_mapped = timestamp_field + created_timestamp_column_mapped = created_timestamp_column + + # Debug logging + logger.debug(f"DataFrame columns: {df.columns.tolist()}") + logger.debug(f"Timestamp field: {timestamp_field_mapped}") + logger.debug( + f"Created timestamp column: {created_timestamp_column_mapped}" + ) + logger.debug(f"DataFrame shape: {df.shape}") + + # Handle empty DataFrame case + if df.empty: + logger.info( + "DataFrame is empty after filtering, creating empty DataFrame with required columns" + ) + # Create an empty DataFrame with the required columns + empty_columns = ( + join_key_columns + + feature_name_columns + + [timestamp_field_mapped] + ) + if created_timestamp_column_mapped: + empty_columns.append(created_timestamp_column_mapped) + if not join_key_columns: + empty_columns.append(DUMMY_ENTITY_ID) + + # Add event_timestamp column for pandas backend compatibility + if "event_timestamp" not in empty_columns: + empty_columns.append("event_timestamp") + + # Create empty DataFrame with proper column types + empty_df = pd.DataFrame(columns=empty_columns) + return empty_df + + # Ensure timestamp is properly formatted + if timestamp_field_mapped in df.columns: + df[timestamp_field_mapped] = ( + pd.to_datetime( + df[timestamp_field_mapped], utc=True, errors="coerce" + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + if ( + created_timestamp_column_mapped + and created_timestamp_column_mapped in df.columns + ): + df[created_timestamp_column_mapped] = ( + pd.to_datetime( + df[created_timestamp_column_mapped], + utc=True, + errors="coerce", + ) + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + # Prepare columns to select + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + + all_required_columns = ( + join_key_columns + feature_name_columns + timestamp_columns + ) + + # Add dummy entity if no join keys + if not join_key_columns: + df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + all_required_columns.append(DUMMY_ENTITY_ID) + + # Select only the required columns that exist + available_columns = [ + col for col in all_required_columns if col in df.columns + ] + df = df[available_columns] + + # Sort by timestamp (most recent first) + # Filter out timestamp columns that don't exist in the dataframe + existing_timestamp_columns = [ + col for col in timestamp_columns if col in df.columns + ] + if existing_timestamp_columns: + df = df.sort_values(existing_timestamp_columns, ascending=False) + + # Reset index + df = df.reset_index(drop=True) + + # Ensure 'event_timestamp' column exists for pandas backend compatibility + if ( + "event_timestamp" not in df.columns + and timestamp_field_mapped != "event_timestamp" + ): + if timestamp_field_mapped in df.columns: + df["event_timestamp"] = df[timestamp_field_mapped] + logger.debug( + f"Added 'event_timestamp' column from '{timestamp_field_mapped}' for pandas backend compatibility" + ) + + return df + except Exception as e: raise RuntimeError(f"Failed to load data from {source_path}: {e}") diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py index ea6cdbaa3c8..43628e7ea1a 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py @@ -2,13 +2,12 @@ import tempfile from typing import Any, Dict, Optional -from sdk.python.feast.infra.offline_stores.contrib.ray_offline_store.ray import ( - RayOfflineStoreConfig, -) - from feast.data_format import ParquetFormat from feast.data_source import DataSource from feast.feature_logging import LoggingDestination +from feast.infra.offline_stores.contrib.ray_offline_store.ray import ( + RayOfflineStoreConfig, +) from feast.infra.offline_stores.file_source import ( FileLoggingDestination, FileSource, @@ -33,8 +32,8 @@ def __init__(self, project_name: str, *args, **kwargs): ray_address=None, use_ray_cluster=False, ) - self.files = [] - self.dirs = [] + self.files: list[Any] = [] + self.dirs: list[str] = [] def create_offline_store_config(self) -> FeastConfigBaseModel: return self.offline_store_config From 5b629b4620cf4c5d48c4f0c48ad997761cd265f6 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Wed, 9 Jul 2025 11:18:53 +0530 Subject: [PATCH 04/10] Fixed tests Signed-off-by: ntkathole --- .../contrib/ray_offline_store/ray.py | 958 +++++++++++------- 1 file changed, 616 insertions(+), 342 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 39a43c7c001..e7109be8e31 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -31,12 +31,118 @@ from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage, ValidationReference -from feast.utils import _get_column_names, _utc_now, make_df_tzaware +from feast.type_map import feast_value_type_to_pandas_type +from feast.utils import _get_column_names, make_df_tzaware +from feast.value_type import ValueType logger = logging.getLogger(__name__) +def _convert_feature_column_types( + batch: pd.DataFrame, feature_views: List[FeatureView] +) -> pd.DataFrame: + """ + Convert feature columns to appropriate pandas types using Feast's type mapping utilities. + + Args: + batch: DataFrame containing feature data + feature_views: List of feature views with type information + + Returns: + DataFrame with properly converted feature column types + """ + batch = batch.copy() + + for fv in feature_views: + for feature in fv.features: + feat_name = feature.name + + # Check if this feature exists in the batch + if feat_name not in batch.columns: + continue + + try: + # Get the Feast ValueType for this feature + value_type = feature.dtype.to_value_type() + + # Handle array/list types + if value_type.name.endswith("_LIST"): + batch[feat_name] = _convert_array_column( + batch[feat_name], value_type + ) + else: + # Handle scalar types using feast type mapping + target_pandas_type = feast_value_type_to_pandas_type(value_type) + batch[feat_name] = _convert_scalar_column( + batch[feat_name], value_type, target_pandas_type + ) + + except Exception as e: + logger.warning( + f"Failed to convert feature {feat_name} to proper type: {e}" + ) + # Keep original dtype if conversion fails + continue + + return batch + + +def _convert_scalar_column( + series: pd.Series, value_type: ValueType, target_pandas_type: str +) -> pd.Series: + """Convert a scalar feature column to the appropriate pandas type.""" + if value_type == ValueType.INT32: + return pd.to_numeric(series, errors="coerce").astype("Int32") + elif value_type == ValueType.INT64: + return pd.to_numeric(series, errors="coerce").astype("Int64") + elif value_type in [ValueType.FLOAT, ValueType.DOUBLE]: + return pd.to_numeric(series, errors="coerce").astype("float64") + elif value_type == ValueType.BOOL: + return series.astype("boolean") + elif value_type == ValueType.STRING: + return series.astype("string") + elif value_type == ValueType.UNIX_TIMESTAMP: + return pd.to_datetime(series, unit="s", errors="coerce") + else: + # For other types, use pandas default conversion + return series.astype(target_pandas_type) + + +def _convert_array_column(series: pd.Series, value_type: ValueType) -> pd.Series: + """Convert an array feature column to the appropriate type with proper empty array handling.""" + # Determine the base type for array elements + base_type_map = { + ValueType.INT32_LIST: np.int32, + ValueType.INT64_LIST: np.int64, + ValueType.FLOAT_LIST: np.float32, + ValueType.DOUBLE_LIST: np.float64, + ValueType.BOOL_LIST: np.bool_, + ValueType.STRING_LIST: object, + ValueType.BYTES_LIST: object, + ValueType.UNIX_TIMESTAMP_LIST: "datetime64[s]", + } + + target_dtype = base_type_map.get(value_type, object) + + def convert_array_item(item): + if item is None or (isinstance(item, list) and len(item) == 0): + # Return properly typed empty array + if target_dtype == object: + return np.array([], dtype=object) + else: + return np.array([], dtype=target_dtype) + else: + # Return the item as-is for non-empty arrays + return item + + return series.apply(convert_array_item) + + class RayOfflineStoreConfig(FeastConfigBaseModel): + """ + Configuration for the Ray Offline Store. + """ + type: Literal[ "feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore", "ray" ] = "ray" @@ -53,61 +159,63 @@ class RayOfflineStoreConfig(FeastConfigBaseModel): class RayResourceManager: - """Manages Ray cluster resources for optimal performance.""" + """ + Manages Ray cluster resources for optimal performance. + """ def __init__(self, config: Optional[RayOfflineStoreConfig] = None): + """ + Initialize the resource manager with cluster resource information. + """ self.config = config or RayOfflineStoreConfig() self.cluster_resources = ray.cluster_resources() - self.available_memory = self.cluster_resources.get( - "memory", 8 * 1024**3 - ) # 8GB default + self.available_memory = self.cluster_resources.get("memory", 8 * 1024**3) self.available_cpus = int(self.cluster_resources.get("CPU", 4)) self.num_nodes = len(ray.nodes()) if ray.is_initialized() else 1 - def configure_ray_context(self): - """Configure Ray DatasetContext for optimal performance.""" + def configure_ray_context(self) -> None: + """ + Configure Ray DatasetContext for optimal performance based on available resources. + """ ctx = DatasetContext.get_current() - # Set buffer sizes based on available memory - if self.available_memory > 32 * 1024**3: # 32GB - ctx.target_shuffle_buffer_size = 2 * 1024**3 # 2GB - ctx.target_max_block_size = 512 * 1024**2 # 512MB + if self.available_memory > 32 * 1024**3: + ctx.target_shuffle_buffer_size = 2 * 1024**3 + ctx.target_max_block_size = 512 * 1024**2 else: - ctx.target_shuffle_buffer_size = 512 * 1024**2 # 512MB - ctx.target_max_block_size = 128 * 1024**2 # 128MB - - # Configure parallelism + ctx.target_shuffle_buffer_size = 512 * 1024**2 + ctx.target_max_block_size = 128 * 1024**2 ctx.min_parallelism = self.available_cpus - ctx.max_parallelism = ( - self.available_cpus * self.config.max_parallelism_multiplier + multiplier = ( + self.config.max_parallelism_multiplier + if self.config.max_parallelism_multiplier is not None + else 2 ) - - # Optimize for feature store workloads - ctx.shuffle_strategy = "sort" + ctx.max_parallelism = self.available_cpus * multiplier + ctx.shuffle_strategy = "sort" # type: ignore ctx.enable_tensor_extension_casting = False - logger.info( f"Configured Ray context: {self.available_cpus} CPUs, " f"{self.available_memory // 1024**3}GB memory, {self.num_nodes} nodes" ) def estimate_optimal_partitions(self, dataset_size_bytes: int) -> int: - """Estimate optimal number of partitions for a dataset.""" - # Use configured target partition size + """ + Estimate optimal number of partitions for a dataset based on size and resources. + """ target_partition_size = (self.config.target_partition_size_mb or 64) * 1024**2 size_based_partitions = max(1, dataset_size_bytes // target_partition_size) - - # Don't exceed configured max parallelism max_partitions = self.available_cpus * ( self.config.max_parallelism_multiplier or 2 ) - return min(size_based_partitions, max_partitions) def should_use_broadcast_join( self, dataset_size_bytes: int, threshold_mb: Optional[int] = None ) -> bool: - """Determine if dataset is small enough for broadcast join.""" + """ + Determine if dataset is small enough for broadcast join. + """ threshold = ( threshold_mb if threshold_mb is not None @@ -118,20 +226,18 @@ def should_use_broadcast_join( def estimate_processing_requirements( self, dataset_size_bytes: int, operation_type: str ) -> Dict[str, Any]: - """Estimate resource requirements for different operations.""" - - # Memory requirements (with safety margin) + """ + Estimate resource requirements for different operations. + """ memory_multiplier = { "read": 1.2, # 20% overhead for reading "join": 3.0, # 3x for join operations "aggregate": 2.0, # 2x for aggregations "shuffle": 2.5, # 2.5x for shuffling } - required_memory = dataset_size_bytes * memory_multiplier.get( operation_type, 2.0 ) - return { "required_memory": required_memory, "optimal_partitions": self.estimate_optimal_partitions(dataset_size_bytes), @@ -141,27 +247,136 @@ def estimate_processing_requirements( class RayDataProcessor: - """Optimized data processing with Ray for feature store operations.""" + """ + Optimized data processing with Ray for feature store operations. + """ def __init__(self, resource_manager: RayResourceManager): + """ + Initialize the data processor with a resource manager. + """ self.resource_manager = resource_manager def optimize_dataset_for_join(self, ds: Dataset, join_keys: List[str]) -> Dataset: - """Optimize dataset partitioning for join operations.""" - - # Estimate optimal partitions + """ + Optimize dataset partitioning for join operations. + """ dataset_size = ds.size_bytes() optimal_partitions = self.resource_manager.estimate_optimal_partitions( dataset_size ) - if not join_keys: # For datasets without join keys, use simple repartitioning return ds.repartition(num_blocks=optimal_partitions) - # For datasets with join keys, use shuffle for better distribution return ds.random_shuffle(num_blocks=optimal_partitions) + def _manual_point_in_time_join( + self, + batch_df: pd.DataFrame, + features_df: pd.DataFrame, + join_keys: List[str], + feature_join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + ) -> pd.DataFrame: + """ + Perform manual point-in-time join when merge_asof fails. + + This method handles cases where merge_asof cannot be used due to: + - Entity mapping (different column names) + - Complex multi-entity joins + - Sorting issues with the data + """ + result = batch_df.copy() + for feat in requested_feats: + is_list_feature = False + if feat in features_df.columns: + sample_values = features_df[feat].dropna() + if not sample_values.empty: + sample_value = sample_values.iloc[0] + if isinstance(sample_value, (list, np.ndarray)): + is_list_feature = True + elif ( + features_df[feat].dtype == object + and sample_values.apply( + lambda x: isinstance(x, (list, np.ndarray)) + ).any() + ): + is_list_feature = True + + if is_list_feature: + result[feat] = [[] for _ in range(len(result))] + else: + # Check if the feature column is datetime + if feat in features_df.columns and pd.api.types.is_datetime64_any_dtype( + features_df[feat] + ): + result[feat] = pd.Series( + [pd.NaT] * len(result), dtype="datetime64[ns, UTC]" + ) + else: + result[feat] = np.nan + + for _, entity_row in batch_df.iterrows(): + entity_matches = pd.Series( + [True] * len(features_df), index=features_df.index + ) + for entity_key, feature_key in zip(join_keys, feature_join_keys): + if entity_key in entity_row and feature_key in features_df.columns: + entity_value = entity_row[entity_key] + feature_column = features_df[feature_key] + if pd.api.types.is_scalar(entity_value): + entity_matches &= feature_column == entity_value + else: + if hasattr(entity_value, "__len__") and len(entity_value) > 0: + entity_matches &= feature_column.isin(entity_value) + else: + entity_matches &= pd.Series( + [False] * len(features_df), index=features_df.index + ) + + if not entity_matches.any(): + continue + + matching_features = features_df[entity_matches] + + if matching_features.empty: + continue + + entity_timestamp = entity_row[timestamp_field] + if timestamp_field in matching_features.columns: + time_matches = matching_features[timestamp_field] <= entity_timestamp + matching_features = matching_features[time_matches] + + if matching_features.empty: + continue + + if timestamp_field in matching_features.columns: + matching_features = matching_features.sort_values(timestamp_field) + latest_feature = matching_features.iloc[-1] + else: + latest_feature = matching_features.iloc[-1] + + entity_index = entity_row.name + for feat in requested_feats: + if feat in latest_feature: + feature_value = latest_feature[feat] + if pd.api.types.is_scalar(feature_value): + if pd.notna(feature_value): + result.loc[entity_index, feat] = feature_value + elif isinstance(feature_value, (list, tuple, np.ndarray)): + result.at[entity_index, feat] = feature_value + else: + try: + if pd.notna(feature_value): + result.at[entity_index, feat] = feature_value + except (ValueError, TypeError): + if feature_value is not None: + result.at[entity_index, feat] = feature_value + + return result + def broadcast_join_features( self, entity_ds: Dataset, @@ -182,34 +397,46 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: """Join a batch with broadcast feature data.""" features = ray.get(feature_ref) - logger.debug(f"Broadcast join - Features DataFrame shape: {features.shape}") - logger.debug( - f"Broadcast join - Features DataFrame columns: {list(features.columns)}" - ) - logger.debug(f"Broadcast join - Requested features: {requested_feats}") - logger.debug(f"Broadcast join - Join keys: {join_keys}") - logger.debug(f"Broadcast join - Timestamp field: {timestamp_field}") - logger.debug(f"Broadcast join - Batch DataFrame shape: {batch.shape}") - logger.debug( - f"Broadcast join - Batch DataFrame columns: {list(batch.columns)}" + logger.info( + f"Processing feature view {feature_view_name} with join keys {join_keys}" ) - if feature_view_name: - logger.info( - f"Processing feature view {feature_view_name} with join keys {join_keys}" - ) + + # Determine feature join keys + # For entity mapping (join key mapping), original_join_keys contains the original feature view join keys + # join_keys contains the mapped entity join keys + if original_join_keys: + # Entity mapping case: entity has join_keys, features have original_join_keys + feature_join_keys = original_join_keys + entity_join_keys = join_keys + else: + # Normal case: both use the same join keys + feature_join_keys = join_keys + entity_join_keys = join_keys # Select only required feature columns plus join keys and timestamp - # Use original join keys for filtering if provided (for entity mapping) - filter_join_keys = original_join_keys if original_join_keys else join_keys - feature_cols = [timestamp_field] + filter_join_keys + requested_feats - features_filtered = features[feature_cols].copy() + feature_cols = [timestamp_field] + feature_join_keys + requested_feats - logger.debug( - f"Broadcast join - Features filtered shape: {features_filtered.shape}" - ) - logger.debug( - f"Broadcast join - Features filtered columns: {list(features_filtered.columns)}" - ) + # Only include columns that actually exist in the features DataFrame + available_feature_cols = [ + col for col in feature_cols if col in features.columns + ] + + # Ensure we have the minimum required columns + if timestamp_field not in available_feature_cols: + raise ValueError( + f"Timestamp field '{timestamp_field}' not found in features columns: {list(features.columns)}" + ) + + # Check if required feature columns exist + missing_feats = [ + feat for feat in requested_feats if feat not in features.columns + ] + if missing_feats: + raise ValueError( + f"Requested features {missing_feats} not found in features columns: {list(features.columns)}" + ) + + features_filtered = features[available_feature_cols].copy() # Ensure timestamp columns have compatible dtypes and precision if timestamp_field in batch.columns: @@ -228,7 +455,7 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: .astype("datetime64[ns, UTC]") ) - if not join_keys: + if not entity_join_keys: # Temporal join without entity keys batch_sorted = batch.sort_values(timestamp_field).reset_index(drop=True) features_sorted = features_filtered.sort_values( @@ -241,22 +468,20 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: direction="backward", ) else: - # Temporal join with entity keys - # Clean data first by removing NaN values - # Use original join keys for filtering feature dataset, mapped join keys for entity dataset - feature_join_keys = ( - original_join_keys if original_join_keys else join_keys - ) - - for key in join_keys: + # Ensure entity join keys exist in batch + for key in entity_join_keys: if key not in batch.columns: - batch[key] = None + batch[key] = np.nan + + # Ensure feature join keys exist in features for key in feature_join_keys: if key not in features_filtered.columns: - features_filtered[key] = None + features_filtered[key] = np.nan # Drop rows with NaN values in join keys or timestamp - batch_clean = batch.dropna(subset=join_keys + [timestamp_field]).copy() + batch_clean = batch.dropna( + subset=entity_join_keys + [timestamp_field] + ).copy() features_clean = features_filtered.dropna( subset=feature_join_keys + [timestamp_field] ).copy() @@ -265,164 +490,136 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: if batch_clean.empty or features_clean.empty: return batch.head(0) # Return empty dataframe with same columns - # Important: For merge_asof with 'by' parameter, sort by 'by' columns first, then by 'on' column - # Both DataFrames must be sorted identically - batch_sort_columns = join_keys + [timestamp_field] - features_sort_columns = feature_join_keys + [timestamp_field] + # Sort both DataFrames for merge_asof requirements + # merge_asof requires: left sorted by 'on' column, right sorted by ['by'] + ['on'] columns - try: - # For entity mapping, we need to manually join since merge_asof doesn't support different column names - if original_join_keys and original_join_keys != join_keys: - # Manual join for entity mapping - logger.info("Using manual join for entity mapping") - raise ValueError("Entity mapping requires manual join") - - # For multi-entity joins, use manual join to ensure correctness - if len(join_keys) > 1: - logger.info( - f"Using manual join for multi-entity join with keys: {join_keys}" - ) - raise ValueError("Multi-entity join requires manual join") - - # Sort both DataFrames consistently + # For the left DataFrame (batch), sort by timestamp (on column) + if timestamp_field in batch_clean.columns: batch_sorted = batch_clean.sort_values( - batch_sort_columns, ascending=True + timestamp_field, ascending=True ).reset_index(drop=True) + else: + batch_sorted = batch_clean.reset_index(drop=True) + # For the right DataFrame (features), sort by join keys (by columns) + timestamp (on column) + right_sort_columns = [] + + # Add join keys to sort columns (these are the 'by' columns for merge_asof) + for key in feature_join_keys: + if key in features_clean.columns: + right_sort_columns.append(key) + + # Add timestamp field to sort columns (this is the 'on' column for merge_asof) + if timestamp_field in features_clean.columns: + right_sort_columns.append(timestamp_field) + + # Sort the right DataFrame + if right_sort_columns: + # Remove duplicates first, then sort + features_clean = features_clean.drop_duplicates( + subset=right_sort_columns, keep="last" + ) features_sorted = features_clean.sort_values( - features_sort_columns, ascending=True + right_sort_columns, ascending=True ).reset_index(drop=True) + else: + features_sorted = features_clean.reset_index(drop=True) - # Verify sorting (merge_asof requirement) - for key in join_keys: - if not batch_sorted[key].is_monotonic_increasing: - # If not monotonic, we need to handle this differently - logger.warning( - f"Join key {key} is not monotonic, using manual join" + # Verify sorting for merge_asof + if ( + timestamp_field in features_sorted.columns + and len(features_sorted) > 1 + ): + # Check if timestamp is monotonic within each group + if feature_join_keys: + # Group by join keys and check if timestamp is monotonic within each group + grouped = features_sorted.groupby(feature_join_keys, sort=False) + for name, group in grouped: + if not group[timestamp_field].is_monotonic_increasing: + # If not monotonic, sort again more carefully + features_sorted = features_sorted.sort_values( + feature_join_keys + [timestamp_field], + ascending=True, + ).reset_index(drop=True) + break + else: + # No join keys, just check timestamp monotonicity + if not features_sorted[timestamp_field].is_monotonic_increasing: + features_sorted = features_sorted.sort_values( + timestamp_field, ascending=True + ).reset_index(drop=True) + + # Attempt merge_asof with proper error handling + try: + # Remove duplicates from both DataFrames before merge_asof + if feature_join_keys: + # For batch DataFrame, remove duplicates based on join keys + timestamp + batch_dedup_cols = [ + k for k in entity_join_keys if k in batch_sorted.columns + ] + if timestamp_field in batch_sorted.columns: + batch_dedup_cols.append(timestamp_field) + if batch_dedup_cols: + batch_sorted = batch_sorted.drop_duplicates( + subset=batch_dedup_cols, keep="last" ) - raise ValueError(f"Join key {key} is not monotonic") - if not features_sorted[key].is_monotonic_increasing: - logger.warning( - f"Feature join key {key} is not monotonic, using manual join" + + # For features DataFrame, remove duplicates based on join keys + timestamp + feature_dedup_cols = [ + k for k in feature_join_keys if k in features_sorted.columns + ] + if timestamp_field in features_sorted.columns: + feature_dedup_cols.append(timestamp_field) + if feature_dedup_cols: + features_sorted = features_sorted.drop_duplicates( + subset=feature_dedup_cols, keep="last" ) - raise ValueError(f"Feature join key {key} is not monotonic") # Perform merge_asof - result = pd.merge_asof( - batch_sorted, - features_sorted, - on=timestamp_field, - by=join_keys, - direction="backward", - ) - logger.debug( - f"merge_asof succeeded for batch of size {len(batch_sorted)}" - ) - - except (ValueError, KeyError) as e: - # If merge_asof fails, implement manual point-in-time join - logger.warning( - f"merge_asof failed, implementing manual point-in-time join: {e}" - ) - - # Group by join keys and apply point-in-time logic manually - result_chunks = [] - - logger.debug( - f"Manual join - batch_clean shape: {batch_clean.shape}" - ) - logger.debug( - f"Manual join - features_clean shape: {features_clean.shape}" - ) - logger.debug(f"Manual join - join_keys: {join_keys}") - logger.debug( - f"Manual join - feature_join_keys: {feature_join_keys}" - ) - - for join_key_vals, entity_group in batch_clean.groupby(join_keys): - # Create dictionary for filtering features by join keys - if len(join_keys) == 1: - entity_key_filter = {join_keys[0]: join_key_vals} - else: - entity_key_filter = dict(zip(join_keys, join_key_vals)) - - # For entity mapping, map the entity keys to feature keys - feature_key_filter = {} - for i, entity_key in enumerate(join_keys): - feature_key = ( - feature_join_keys[i] - if i < len(feature_join_keys) - else entity_key + if feature_join_keys: + # Handle join keys properly - if they are the same, just use one set + if entity_join_keys == feature_join_keys: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + by=entity_join_keys, + direction="backward", + suffixes=("", "_right"), ) - feature_key_filter[feature_key] = entity_key_filter[ - entity_key - ] - - # Filter features for this join key group - feature_group = features_clean - for key, val in feature_key_filter.items(): - # Only filter if the key exists in the feature dataset - if key in feature_group.columns: - feature_group = feature_group[feature_group[key] == val] - logger.debug( - f"Filtered by {key}={val}: {len(feature_group)} rows remaining" - ) - else: - logger.warning( - f"Join key {key} not found in feature dataset columns: {list(feature_group.columns)}" - ) - # If the key is missing, we can't match, so return empty - feature_group = feature_group.iloc[0:0] - break - - if len(feature_group) == 0: - # No features found, add NaN columns - entity_result = entity_group.copy() - for feat in requested_feats: - if feat not in entity_result.columns: - entity_result[feat] = np.nan - result_chunks.append(entity_result) else: - # Apply point-in-time logic: for each entity timestamp, find the latest feature - entity_result = entity_group.copy() - for feat in requested_feats: - if feat not in entity_result.columns: - entity_result[feat] = np.nan - - # For each row in entity group, find the latest feature value - for idx, entity_row in entity_group.iterrows(): - entity_ts = entity_row[timestamp_field] - # Find features with timestamp <= entity timestamp - valid_features = feature_group[ - feature_group[timestamp_field] <= entity_ts - ] - if len(valid_features) > 0: - # Sort by timestamp to ensure we get the latest feature - valid_features = valid_features.sort_values( - timestamp_field - ) - latest_feature = valid_features.iloc[-1] - # Update the result with feature values - for feat in requested_feats: - if feat in latest_feature: - entity_result.loc[idx, feat] = ( - latest_feature[feat] - ) - - result_chunks.append(entity_result) - - if result_chunks: - result = pd.concat(result_chunks, ignore_index=True) + # Different join keys, use left_by and right_by parameters + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + left_by=entity_join_keys, + right_by=feature_join_keys, + direction="backward", + suffixes=("", "_right"), + ) else: - result = batch_clean.copy() - for feat in requested_feats: - if feat not in result.columns: - result[feat] = np.nan - - # Debug logging for join result - logger.debug(f"Join result shape: {result.shape}") - logger.debug(f"Join result columns: {list(result.columns)}") + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + suffixes=("", "_right"), + ) + except Exception as e: + logger.warning( + f"merge_asof failed: {e}, implementing manual point-in-time join" + ) + # Fall back to manual join + result = self._manual_point_in_time_join( + batch_clean, + features_clean, + entity_join_keys, + feature_join_keys, + timestamp_field, + requested_feats, + ) # Handle feature renaming if full_feature_names is True if full_feature_names and feature_view_name: for feat in requested_feats: @@ -590,6 +787,44 @@ def __init__( self._metadata: Optional[RetrievalMetadata] = None self._full_feature_names: bool = False self._on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None + self._feature_refs: List[str] = [] + self._entity_df: Optional[pd.DataFrame] = None + + def _create_metadata(self) -> RetrievalMetadata: + """Create metadata from the entity DataFrame and feature references.""" + if self._entity_df is not None: + # Get timestamp range from entity DataFrame + if "event_timestamp" in self._entity_df.columns: + timestamps = pd.to_datetime( + self._entity_df["event_timestamp"], utc=True + ) + min_timestamp = timestamps.min().to_pydatetime() + max_timestamp = timestamps.max().to_pydatetime() + + # Get keys (all columns except event_timestamp) + keys = [ + col for col in self._entity_df.columns if col != "event_timestamp" + ] + else: + min_timestamp = None + max_timestamp = None + keys = list(self._entity_df.columns) + else: + min_timestamp = None + max_timestamp = None + keys = [] + + return RetrievalMetadata( + features=self._feature_refs, + keys=keys, + min_event_timestamp=min_timestamp, + max_event_timestamp=max_timestamp, + ) + + def _set_metadata_info(self, feature_refs: List[str], entity_df: pd.DataFrame): + """Set the feature references and entity DataFrame for metadata creation.""" + self._feature_refs = feature_refs + self._entity_df = entity_df def _resolve(self) -> Union[Dataset, pd.DataFrame]: if callable(self._dataset_or_callable): @@ -605,28 +840,41 @@ def to_df( ) -> pd.DataFrame: # Use cached DataFrame if available for repeated access if self._cached_df is not None and not self.on_demand_feature_views: - return self._cached_df - - # If we have on-demand feature views, use the parent's implementation - # which calls to_arrow and applies the transformations - if self.on_demand_feature_views: - logger.info( - f"Using parent implementation for {len(self.on_demand_feature_views)} ODFVs" - ) - return super().to_df( - validation_reference=validation_reference, timeout=timeout - ) - - result = self._resolve() - if isinstance(result, pd.DataFrame): - self._cached_df = result - return result + df = self._cached_df + else: + # If we have on-demand feature views, use the parent's implementation + # which calls to_arrow and applies the transformations + if self.on_demand_feature_views: + logger.info( + f"Using parent implementation for {len(self.on_demand_feature_views)} ODFVs" + ) + df = super().to_df( + validation_reference=validation_reference, timeout=timeout + ) + else: + result = self._resolve() + if isinstance(result, pd.DataFrame): + df = result + else: + df = result.to_pandas() + self._cached_df = df - # Convert Ray Dataset to DataFrame with progress logging - logger.info("Converting Ray dataset to DataFrame...") - self._cached_df = result.to_pandas() - logger.info(f"Converted dataset to DataFrame: {self._cached_df.shape}") - return self._cached_df + # Handle validation reference if provided + if validation_reference: + try: + # Import here to avoid circular imports + from feast.dqm.errors import ValidationFailed + + # Run validation using the validation reference + validation_result = validation_reference.profile.validate(df) + if not validation_result.is_success: + raise ValidationFailed(validation_result) + except ImportError: + logger.warning("DQM profiler not available, skipping validation") + except Exception as e: + logger.error(f"Validation failed: {e}") + raise ValueError(f"Data validation failed: {e}") + return df def to_arrow( self, @@ -635,9 +883,6 @@ def to_arrow( ) -> pa.Table: # If we have ODFVs, use the parent's implementation if self.on_demand_feature_views: - logger.debug( - f"Using parent implementation for {len(self.on_demand_feature_views)} ODFVs" - ) return super().to_arrow( validation_reference=validation_reference, timeout=timeout ) @@ -673,6 +918,8 @@ def to_remote_storage(self) -> list[str]: @property def metadata(self) -> Optional[RetrievalMetadata]: """Return metadata information about retrieval.""" + if self._metadata is None: + self._metadata = self._create_metadata() return self._metadata @property @@ -692,21 +939,10 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: result = self._resolve() if isinstance(result, pd.DataFrame): - logger.debug(f"_to_arrow_internal: DataFrame shape: {result.shape}") - logger.debug( - f"_to_arrow_internal: DataFrame columns: {list(result.columns)}" - ) return pa.Table.from_pandas(result) # For Ray Dataset, convert to pandas first then to arrow - logger.debug( - "_to_arrow_internal: Converting Ray Dataset to pandas then to arrow" - ) df = result.to_pandas() - logger.debug(f"_to_arrow_internal: Converted dataset shape: {df.shape}") - logger.debug( - f"_to_arrow_internal: Converted dataset columns: {list(df.columns)}" - ) return pa.Table.from_pandas(df) def persist( @@ -726,10 +962,20 @@ def persist( if not allow_overwrite and os.path.exists(destination_path): raise SavedDatasetLocationAlreadyExists(location=destination_path) try: - ds = self._resolve() + result = self._resolve() if not destination_path.startswith(("s3://", "gs://", "hdfs://")): os.makedirs(os.path.dirname(destination_path), exist_ok=True) - ds.write_parquet(destination_path) + + # Handle both DataFrame and Ray Dataset + if isinstance(result, pd.DataFrame): + # For DataFrame, convert to Ray Dataset first + RayOfflineStore._ensure_ray_initialized() + ds = ray.data.from_pandas(result) + ds.write_parquet(destination_path) + else: + # For Ray Dataset, use direct write + result.write_parquet(destination_path) + return destination_path except Exception as e: raise RuntimeError(f"Failed to persist dataset to {destination_path}: {e}") @@ -774,14 +1020,10 @@ def _ensure_ray_initialized(config: Optional[RepoConfig] = None): def _init_ray(self, config: RepoConfig): ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) - self._ensure_ray_initialized(config) - - # Initialize optimization components if self._resource_manager is None: self._resource_manager = RayResourceManager(ray_config) self._resource_manager.configure_ray_context() - if self._data_processor is None: self._data_processor = RayDataProcessor(self._resource_manager) @@ -814,17 +1056,41 @@ def _create_filtered_dataset( if start_date or end_date: try: if start_date and end_date: - filtered_ds = ds.filter( - lambda row: start_date <= row[timestamp_field] <= end_date - ) + + def filter_func(row): + try: + ts = row[timestamp_field] + return start_date <= ts <= end_date + except KeyError: + raise KeyError( + f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" + ) + + filtered_ds = ds.filter(filter_func) elif start_date: - filtered_ds = ds.filter( - lambda row: row[timestamp_field] >= start_date - ) + + def filter_func(row): + try: + ts = row[timestamp_field] + return ts >= start_date + except KeyError: + raise KeyError( + f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" + ) + + filtered_ds = ds.filter(filter_func) elif end_date: - filtered_ds = ds.filter( - lambda row: row[timestamp_field] <= end_date - ) + + def filter_func(row): + try: + ts = row[timestamp_field] + return ts <= end_date + except KeyError: + raise KeyError( + f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" + ) + + filtered_ds = ds.filter(filter_func) else: return ds @@ -887,11 +1153,6 @@ def get_historical_features( regular_feature_views = [ fv for fv in feature_views if fv.name not in odfv_names ] - - logger.info( - f"Processing {len(regular_feature_views)} regular feature views and {len(on_demand_feature_views)} on-demand feature views with {len(feature_refs)} feature references" - ) - # Apply field mappings to entity dataset if needed global_field_mappings = {} for fv in regular_feature_views: @@ -924,9 +1185,6 @@ def get_historical_features( ] if not fv_feature_refs: continue - - logger.info(f"Processing feature view: {fv.name}") - # Get join configuration entities = fv.entities or [] entity_objs = [registry.get_entity(e, project) for e in entities] @@ -957,10 +1215,6 @@ def get_historical_features( f"(available: {available_feature_names})" ) - logger.info( - f"Feature view '{fv.name}': requesting {requested_feats}, available: {available_feature_names}" - ) - # Load feature data as Ray dataset source_path = store._get_source_path(fv.batch_source, config) feature_ds = ray.data.read_parquet(source_path) @@ -979,13 +1233,6 @@ def get_historical_features( if created_col: created_col = field_mapping.get(created_col, created_col) - # Apply projection join key mapping to entity dataset if needed - if fv.projection.join_key_map: - # The feature dataset keeps its original columns (e.g., location_id) - # The entity dataset gets the mapped columns (e.g., origin_id, destination_id) - # We need to ensure the entity dataset has the properly mapped columns - pass # The entity dataset already has the mapped columns in this case - # Ensure timestamp compatibility in entity dataset if ( timestamp_field != "event_timestamp" @@ -1042,7 +1289,7 @@ def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: requested_feats, full_feature_names, fv.projection.name_to_use(), - original_join_keys, + original_join_keys if fv.projection.join_key_map else None, ) else: # Use distributed windowed join for large feature datasets @@ -1078,14 +1325,13 @@ def normalize_timestamps(batch: pd.DataFrame) -> pd.DataFrame: window_size=config.offline_store.window_size_for_joins, full_feature_names=full_feature_names, feature_view_name=fv.projection.name_to_use(), - original_join_keys=original_join_keys, + original_join_keys=original_join_keys + if fv.projection.join_key_map + else None, ) # Final processing: clean up and ensure proper column structure def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: - logger.debug(f"Finalizing result - input columns: {list(batch.columns)}") - logger.debug(f"Finalizing result - batch shape: {batch.shape}") - batch = batch.copy() # Preserve existing feature columns (including renamed ones) @@ -1120,7 +1366,9 @@ def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: elif timestamp_field in batch.columns: batch["event_timestamp"] = batch[timestamp_field] - logger.debug(f"Final columns: {list(batch.columns)}") + # Fix data types for feature columns using centralized type mapping utilities + batch = _convert_feature_column_types(batch, regular_feature_views) + return batch result_ds = result_ds.map_batches(finalize_result, batch_format="pandas") @@ -1134,8 +1382,9 @@ def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: job = RayRetrievalJob(result_ds, staging_location=storage_path) job._full_feature_names = full_feature_names job._on_demand_feature_views = on_demand_feature_views - - logger.info("Historical features processing completed successfully") + job._feature_refs = feature_refs + job._entity_df = original_entity_df + job._metadata = job._create_metadata() return job def validate_data_source( @@ -1177,9 +1426,16 @@ def pull_latest_from_table_or_query( def _load(): try: - # Load and filter the dataset + # Get field mapping for column renaming after loading + field_mapping = getattr(data_source, "field_mapping", None) + + # The timestamp_field parameter is already the original field name + # (reverse mapping is handled by _get_column_names) + filter_timestamp_field = timestamp_field + + # Load and filter the dataset using the original timestamp field name ds = RayOfflineStore._create_filtered_dataset( - source_path, timestamp_field, start_date, end_date + source_path, filter_timestamp_field, start_date, end_date ) # Convert to pandas for deduplication and column selection @@ -1187,19 +1443,25 @@ def _load(): df = make_df_tzaware(df) # Apply field mapping if needed - field_mapping = getattr(data_source, "field_mapping", None) if field_mapping: df = df.rename(columns=field_mapping) - # Use the actual timestamp field name (this is already the correct mapped name) - timestamp_field_mapped = timestamp_field - created_timestamp_column_mapped = created_timestamp_column + # Now use the mapped timestamp field name for all operations + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get( + created_timestamp_column, created_timestamp_column + ) + if field_mapping and created_timestamp_column + else created_timestamp_column + ) # Handle empty DataFrame case if df.empty: - logger.info( - "DataFrame is empty after filtering, creating empty DataFrame with required columns" - ) # Create an empty DataFrame with the required columns empty_columns = ( join_key_columns @@ -1293,9 +1555,6 @@ def _load(): ): if timestamp_field_mapped in df.columns: df["event_timestamp"] = df[timestamp_field_mapped] - logger.debug( - f"Added 'event_timestamp' column from '{timestamp_field_mapped}' for pandas backend compatibility" - ) return df @@ -1328,9 +1587,16 @@ def pull_all_from_table_or_query( def _load(): try: - # Load and filter the dataset + # Get field mapping for column renaming after loading + field_mapping = getattr(data_source, "field_mapping", None) + + # The timestamp_field parameter is already the original field name + # (reverse mapping is handled by _get_column_names) + filter_timestamp_field = timestamp_field + + # Load and filter the dataset using the original timestamp field name ds = RayOfflineStore._create_filtered_dataset( - source_path, timestamp_field, start_date, end_date + source_path, filter_timestamp_field, start_date, end_date ) # Convert to pandas for column selection @@ -1338,27 +1604,25 @@ def _load(): df = make_df_tzaware(df) # Apply field mapping if needed - field_mapping = getattr(data_source, "field_mapping", None) if field_mapping: df = df.rename(columns=field_mapping) - # Use the actual timestamp field name (this is already the correct mapped name) - timestamp_field_mapped = timestamp_field - created_timestamp_column_mapped = created_timestamp_column - - # Debug logging - logger.debug(f"DataFrame columns: {df.columns.tolist()}") - logger.debug(f"Timestamp field: {timestamp_field_mapped}") - logger.debug( - f"Created timestamp column: {created_timestamp_column_mapped}" + # Now use the mapped timestamp field name for all operations + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get( + created_timestamp_column, created_timestamp_column + ) + if field_mapping and created_timestamp_column + else created_timestamp_column ) - logger.debug(f"DataFrame shape: {df.shape}") # Handle empty DataFrame case if df.empty: - logger.info( - "DataFrame is empty after filtering, creating empty DataFrame with required columns" - ) # Create an empty DataFrame with the required columns empty_columns = ( join_key_columns @@ -1440,9 +1704,6 @@ def _load(): ): if timestamp_field_mapped in df.columns: df["event_timestamp"] = df[timestamp_field_mapped] - logger.debug( - f"Added 'event_timestamp' column from '{timestamp_field_mapped}' for pandas backend compatibility" - ) return df @@ -1499,25 +1760,38 @@ def offline_write_batch( repo_path = getattr(config, "repo_path", None) or os.getcwd() ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) - base_storage_path = ray_config.storage_path or "/tmp/ray-storage" - - batch_source_path = getattr(feature_view.batch_source, "file_path", None) - if not batch_source_path: - batch_source_path = f"{feature_view.name}/push_{_utc_now()}.parquet" + assert isinstance(feature_view.batch_source, FileSource) + # Use the existing batch source path (like other offline stores) + batch_source_path = feature_view.batch_source.file_options.uri feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) - storage_path = FileSource.get_uri_for_file_path(repo_path, base_storage_path) - - feature_dir = os.path.dirname(feature_path) - if not feature_dir.startswith(("s3://", "gs://")): - os.makedirs(feature_dir, exist_ok=True) - if not storage_path.startswith(("s3://", "gs://")): - os.makedirs(os.path.dirname(storage_path), exist_ok=True) - - df = table.to_pandas() - ds = ray.data.from_pandas(df) - ds.materialize() - ds.write_parquet(feature_dir) + + # If the path points to a file, write directly to that file location + # If it points to a directory, write to that directory + if feature_path.endswith(".parquet"): + # Convert PyArrow table to pandas DataFrame + df = table.to_pandas() + + # Check if file exists and append if it does + if os.path.exists(feature_path): + # Read existing data + existing_df = pd.read_parquet(feature_path) + # Append new data + combined_df = pd.concat([existing_df, df], ignore_index=True) + # Write combined data + combined_df.to_parquet(feature_path, index=False) + else: + # Write new data + df.to_parquet(feature_path, index=False) + else: + # Write to directory (multiple parquet files) + os.makedirs(feature_path, exist_ok=True) + + # Convert PyArrow table to Ray dataset + ds = ray.data.from_arrow(table) + + # Write to parquet + ds.write_parquet(feature_path) @staticmethod def create_saved_dataset_destination( From a806ae374dd99a1bdd1ad4679013e99cca2bf5a4 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Thu, 10 Jul 2025 00:28:17 +0530 Subject: [PATCH 05/10] fix: Added utils for validation Signed-off-by: ntkathole --- .../contrib/ray_offline_store/ray.py | 809 +++++++++++------- 1 file changed, 486 insertions(+), 323 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index e7109be8e31..de3ed55ce7f 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -27,6 +27,13 @@ RetrievalJob, RetrievalMetadata, ) +from feast.infra.offline_stores.offline_utils import ( + assert_expected_columns_in_entity_df, + get_entity_df_timestamp_bounds, + get_expected_join_keys, + get_pyarrow_schema_from_batch_source, + infer_event_timestamp_from_entity_df, +) from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -38,16 +45,292 @@ logger = logging.getLogger(__name__) +def _normalize_timestamp_column( + df: pd.DataFrame, column: str, inplace: bool = False +) -> pd.DataFrame: + """ + Normalize a timestamp column to UTC with second precision. + Args: + df: DataFrame containing the timestamp column + column: Name of the timestamp column to normalize + inplace: Whether to modify the DataFrame in place + Returns: + DataFrame with normalized timestamp column + """ + if not inplace: + df = df.copy() + + if column in df.columns: + df[column] = ( + pd.to_datetime(df[column], utc=True, errors="coerce") + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + return df + + +def _normalize_timestamp_columns( + df: pd.DataFrame, columns: List[str], inplace: bool = False +) -> pd.DataFrame: + """ + Normalize multiple timestamp columns to UTC with second precision. + Args: + df: DataFrame containing the timestamp columns + columns: List of timestamp column names to normalize + inplace: Whether to modify the DataFrame in place + Returns: + DataFrame with normalized timestamp columns + """ + if not inplace: + df = df.copy() + + for column in columns: + if column in df.columns: + df = _normalize_timestamp_column(df, column, inplace=True) + + return df + + +def _create_time_window_column( + df: pd.DataFrame, + timestamp_column: str, + window_size: str, + window_column: str = "time_window", + inplace: bool = False, +) -> pd.DataFrame: + """ + Create a time window column by flooring timestamps to specified window size. + Args: + df: DataFrame containing the timestamp column + timestamp_column: Name of the timestamp column + window_size: Window size string (e.g., "1H", "30min") + window_column: Name for the new window column + inplace: Whether to modify the DataFrame in place + Returns: + DataFrame with added time window column + """ + if not inplace: + df = df.copy() + + if timestamp_column in df.columns: + df[window_column] = ( + pd.to_datetime(df[timestamp_column]) + .dt.floor(window_size) + .astype("datetime64[ns, UTC]") + ) + + return df + + +def _create_empty_timestamp_column( + length: int, dtype: str = "datetime64[ns, UTC]" +) -> pd.Series: + """ + Create an empty timestamp column with proper dtype. + Args: + length: Length of the series + dtype: Pandas dtype for the timestamp column + Returns: + Series with NaT values and proper datetime dtype + """ + return pd.Series([pd.NaT] * length, dtype=dtype) + + +def _ensure_timestamp_compatibility( + df: pd.DataFrame, timestamp_fields: List[str], inplace: bool = False +) -> pd.DataFrame: + """ + Ensure timestamp columns have compatible dtypes and precision for joins. + Args: + df: DataFrame to process + timestamp_fields: List of timestamp field names + inplace: Whether to modify the DataFrame in place + Returns: + DataFrame with compatible timestamp columns + """ + if not inplace: + df = df.copy() + + # Use existing utility for timezone awareness + df = make_df_tzaware(df) + + # Then normalize timestamp precision for specified fields only + for field in timestamp_fields: + if field in df.columns: + df = _normalize_timestamp_column(df, field, inplace=True) + + return df + + +def _create_empty_dataframe_with_timestamp_columns( + columns: List[str], timestamp_columns: List[str] +) -> pd.DataFrame: + """ + Create an empty DataFrame with proper column types including datetime columns. + Args: + columns: List of all column names + timestamp_columns: List of timestamp column names that need proper dtype + Returns: + Empty DataFrame with proper column types + """ + df = pd.DataFrame(columns=columns) + + # Set proper dtype for timestamp columns + for col in timestamp_columns: + if col in df.columns: + df[col] = df[col].astype("datetime64[ns, UTC]") + + return df + + +def _safe_infer_event_timestamp_column( + entity_df: pd.DataFrame, fallback_column: str = "event_timestamp" +) -> str: + """ + Safely infer the event timestamp column using offline_utils with fallback. + Args: + entity_df: Entity DataFrame to analyze + fallback_column: Default column name to use if inference fails + Returns: + Inferred or fallback timestamp column name + """ + try: + return infer_event_timestamp_from_entity_df(entity_df.dtypes.to_dict()) + except Exception as e: + logger.debug( + f"Timestamp column inference failed: {e}, using fallback: {fallback_column}" + ) + return fallback_column + + +def _safe_get_entity_timestamp_bounds( + entity_df: pd.DataFrame, timestamp_column: str +) -> Tuple[Optional[datetime], Optional[datetime]]: + """ + Safely get entity timestamp bounds using offline_utils with fallback. + Args: + entity_df: Entity DataFrame + timestamp_column: Name of timestamp column + Returns: + Tuple of (min_timestamp, max_timestamp) or (None, None) if failed + """ + try: + if timestamp_column in entity_df.columns: + min_ts, max_ts = get_entity_df_timestamp_bounds(entity_df, timestamp_column) + # Convert Pandas Timestamp to datetime if needed + if hasattr(min_ts, "to_pydatetime"): + min_ts = min_ts.to_pydatetime() + if hasattr(max_ts, "to_pydatetime"): + max_ts = max_ts.to_pydatetime() + return min_ts, max_ts + except Exception as e: + logger.debug( + f"Timestamp bounds extraction failed: {e}, falling back to manual calculation" + ) + + # Fallback to original logic + try: + if timestamp_column in entity_df.columns: + timestamps = pd.to_datetime(entity_df[timestamp_column], utc=True) + return timestamps.min().to_pydatetime(), timestamps.max().to_pydatetime() + except Exception: + pass + + return None, None + + +def _safe_validate_entity_dataframe( + entity_df: pd.DataFrame, + feature_views: List[FeatureView], + project: str, + registry: BaseRegistry, +) -> None: + """ + Safely validate entity DataFrame using offline_utils with graceful fallback. + Args: + entity_df: Entity DataFrame to validate + feature_views: List of feature views to validate against + project: Feast project name + registry: Feature registry + """ + try: + # Get expected join keys for validation + expected_join_keys = get_expected_join_keys(project, feature_views, registry) + + # Infer event timestamp column + timestamp_col = infer_event_timestamp_from_entity_df(entity_df.dtypes.to_dict()) + + # Validate entity DataFrame has required columns + assert_expected_columns_in_entity_df( + entity_df.dtypes.to_dict(), expected_join_keys, timestamp_col + ) + + logger.info( + f"Entity DataFrame validation passed:\n" + f" Expected join keys: {expected_join_keys}\n" + f" Detected timestamp column: {timestamp_col}" + ) + + except Exception as e: + # Log validation issues but don't fail + logger.warning(f"Entity DataFrame validation skipped due to error: {e}") + logger.debug("Validation error details:", exc_info=True) + + +def _safe_validate_schema( + config: RepoConfig, + data_source: DataSource, + table_columns: List[str], + operation_name: str = "operation", +) -> Optional[Tuple[pa.Schema, List[str]]]: + """ + Safely validate schema using offline_utils with graceful fallback. + Args: + config: Repo configuration + data_source: Data source to validate against + table_columns: Actual table column names + operation_name: Name of operation for logging + Returns: + Tuple of (expected_schema, expected_columns) or None if validation fails + """ + try: + expected_schema, expected_columns = get_pyarrow_schema_from_batch_source( + config, data_source + ) + + if set(expected_columns) != set(table_columns): + logger.warning( + f"Schema mismatch in {operation_name}:\n" + f" Expected columns: {expected_columns}\n" + f" Actual columns: {table_columns}" + ) + + # Check if it's just a column order issue + if set(expected_columns) == set(table_columns): + logger.info(f"Columns match but order differs for {operation_name}") + return expected_schema, expected_columns + else: + logger.debug(f"Schema validation passed for {operation_name}") + return expected_schema, expected_columns + + except Exception as e: + logger.warning( + f"Schema validation skipped for {operation_name} due to error: {e}" + ) + logger.debug("Schema validation error details:", exc_info=True) + + return None + + def _convert_feature_column_types( batch: pd.DataFrame, feature_views: List[FeatureView] ) -> pd.DataFrame: """ Convert feature columns to appropriate pandas types using Feast's type mapping utilities. - Args: batch: DataFrame containing feature data feature_views: List of feature views with type information - Returns: DataFrame with properly converted feature column types """ @@ -312,9 +595,7 @@ def _manual_point_in_time_join( if feat in features_df.columns and pd.api.types.is_datetime64_any_dtype( features_df[feat] ): - result[feat] = pd.Series( - [pd.NaT] * len(result), dtype="datetime64[ns, UTC]" - ) + result[feat] = _create_empty_timestamp_column(len(result)) else: result[feat] = np.nan @@ -439,21 +720,10 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: features_filtered = features[available_feature_cols].copy() # Ensure timestamp columns have compatible dtypes and precision - if timestamp_field in batch.columns: - batch[timestamp_field] = ( - pd.to_datetime(batch[timestamp_field], utc=True, errors="coerce") - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - - if timestamp_field in features_filtered.columns: - features_filtered[timestamp_field] = ( - pd.to_datetime( - features_filtered[timestamp_field], utc=True, errors="coerce" - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) + batch = _normalize_timestamp_column(batch, timestamp_field, inplace=True) + features_filtered = _normalize_timestamp_column( + features_filtered, timestamp_field, inplace=True + ) if not entity_join_keys: # Temporal join without entity keys @@ -688,11 +958,8 @@ def _add_time_windows_and_source_marker( """Add time windows and source markers to dataset.""" def add_window_and_source(batch: pd.DataFrame) -> pd.DataFrame: - batch = batch.copy() - batch["time_window"] = ( - pd.to_datetime(batch[timestamp_field]) - .dt.floor(window_size) - .astype("datetime64[ns, UTC]") + batch = _create_time_window_column( + batch, timestamp_field, window_size, "time_window" ) batch["_data_source"] = source_marker return batch @@ -793,22 +1060,16 @@ def __init__( def _create_metadata(self) -> RetrievalMetadata: """Create metadata from the entity DataFrame and feature references.""" if self._entity_df is not None: - # Get timestamp range from entity DataFrame - if "event_timestamp" in self._entity_df.columns: - timestamps = pd.to_datetime( - self._entity_df["event_timestamp"], utc=True - ) - min_timestamp = timestamps.min().to_pydatetime() - max_timestamp = timestamps.max().to_pydatetime() + # Auto-detect timestamp column and get timestamp bounds using utilities + timestamp_col = _safe_infer_event_timestamp_column( + self._entity_df, "event_timestamp" + ) + min_timestamp, max_timestamp = _safe_get_entity_timestamp_bounds( + self._entity_df, timestamp_col + ) - # Get keys (all columns except event_timestamp) - keys = [ - col for col in self._entity_df.columns if col != "event_timestamp" - ] - else: - min_timestamp = None - max_timestamp = None - keys = list(self._entity_df.columns) + # Get keys (all columns except the detected timestamp column) + keys = [col for col in self._entity_df.columns if col != timestamp_col] else: min_timestamp = None max_timestamp = None @@ -1121,16 +1382,10 @@ def get_historical_features( entity_ds = ray.data.from_pandas(entity_df) original_entity_df = entity_df.copy() - # Make entity dataframe timezone aware - original_entity_df = make_df_tzaware(original_entity_df) - if "event_timestamp" in original_entity_df.columns: - original_entity_df["event_timestamp"] = ( - pd.to_datetime( - original_entity_df["event_timestamp"], utc=True, errors="coerce" - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) + # Make entity dataframe timezone aware and normalize timestamp + original_entity_df = _ensure_timestamp_compatibility( + original_entity_df, ["event_timestamp"] + ) # Parse feature_refs and get ODFVs on_demand_feature_views = OnDemandFeatureView.get_requested_odfvs( @@ -1153,6 +1408,11 @@ def get_historical_features( regular_feature_views = [ fv for fv in feature_views if fv.name not in odfv_names ] + + # Enhanced validation using offline_utils with safe fallback + _safe_validate_entity_dataframe( + original_entity_df, regular_feature_views, project, registry + ) # Apply field mappings to entity dataset if needed global_field_mappings = {} for fv in regular_feature_views: @@ -1242,14 +1502,10 @@ def get_historical_features( def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: batch = batch.copy() - batch[timestamp_field] = ( - pd.to_datetime( - batch["event_timestamp"], utc=True, errors="coerce" - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") + batch[timestamp_field] = batch["event_timestamp"] + return _normalize_timestamp_column( + batch, timestamp_field, inplace=True ) - return batch result_ds = result_ds.map_batches( add_timestamp_field, batch_format="pandas" @@ -1268,16 +1524,9 @@ def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: f"Using broadcast join for {fv.name} (size: {feature_size // 1024**2}MB)" ) feature_df = feature_ds.to_pandas() - feature_df = make_df_tzaware(feature_df) - - if timestamp_field in feature_df.columns: - feature_df[timestamp_field] = ( - pd.to_datetime( - feature_df[timestamp_field], utc=True, errors="coerce" - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) + feature_df = _ensure_timestamp_compatibility( + feature_df, [timestamp_field] + ) if store._data_processor is None: raise ValueError("Data processor not initialized") @@ -1299,16 +1548,7 @@ def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: # Ensure timestamp format in feature dataset def normalize_timestamps(batch: pd.DataFrame) -> pd.DataFrame: - batch = make_df_tzaware(batch) - if timestamp_field in batch.columns: - batch[timestamp_field] = ( - pd.to_datetime( - batch[timestamp_field], utc=True, errors="coerce" - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - return batch + return _ensure_timestamp_compatibility(batch, [timestamp_field]) feature_ds = feature_ds.map_batches( normalize_timestamps, batch_format="pandas" @@ -1355,13 +1595,10 @@ def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: if "event_timestamp" not in batch.columns: if "event_timestamp" in original_entity_df.columns: batch["event_timestamp"] = ( - pd.to_datetime( - original_entity_df["event_timestamp"].iloc[: len(batch)], - utc=True, - errors="coerce", - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") + original_entity_df["event_timestamp"].iloc[: len(batch)].values + ) + batch = _normalize_timestamp_column( + batch, "event_timestamp", inplace=True ) elif timestamp_field in batch.columns: batch["event_timestamp"] = batch[timestamp_field] @@ -1408,6 +1645,118 @@ def supports_remote_storage_export(self) -> bool: """Check if remote storage export is supported.""" return self._staging_location is not None + @staticmethod + def _load_and_filter_dataset( + source_path: str, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: Optional[datetime], + end_date: Optional[datetime], + ) -> pd.DataFrame: + """ + Common method to load and filter dataset for both pull_latest and pull_all methods. + Args: + source_path: Path to the data source + data_source: DataSource object containing field mapping + join_key_columns: List of join key columns + feature_name_columns: List of feature columns + timestamp_field: Name of the timestamp field + created_timestamp_column: Optional created timestamp column + start_date: Optional start date for filtering + end_date: Optional end date for filtering + Returns: + Processed pandas DataFrame + """ + try: + # Get field mapping for column renaming after loading + field_mapping = getattr(data_source, "field_mapping", None) + + # Load and filter the dataset using the original timestamp field name + ds = RayOfflineStore._create_filtered_dataset( + source_path, timestamp_field, start_date, end_date + ) + + # Convert to pandas for processing + df = ds.to_pandas() + df = make_df_tzaware(df) + + # Apply field mapping if needed + if field_mapping: + df = df.rename(columns=field_mapping) + + # Get mapped field names + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get(created_timestamp_column, created_timestamp_column) + if field_mapping and created_timestamp_column + else created_timestamp_column + ) + + # Build timestamp columns list + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + + # Normalize timestamp columns + df = _normalize_timestamp_columns(df, timestamp_columns, inplace=True) + + # Handle empty DataFrame case + if df.empty: + empty_columns = ( + join_key_columns + feature_name_columns + timestamp_columns + ) + if not join_key_columns: + empty_columns.append(DUMMY_ENTITY_ID) + if "event_timestamp" not in empty_columns: + empty_columns.append("event_timestamp") + return _create_empty_dataframe_with_timestamp_columns( + empty_columns, timestamp_columns + ) + + # Build required columns list + all_required_columns = ( + join_key_columns + feature_name_columns + timestamp_columns + ) + if not join_key_columns: + df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + all_required_columns.append(DUMMY_ENTITY_ID) + + # Select only the required columns that exist + available_columns = [ + col for col in all_required_columns if col in df.columns + ] + df = df[available_columns] + + # Basic sorting by timestamp (most recent first) + existing_timestamp_columns = [ + col for col in timestamp_columns if col in df.columns + ] + if existing_timestamp_columns: + df = df.sort_values(existing_timestamp_columns, ascending=False) + + # Reset index + df = df.reset_index(drop=True) + + # Ensure 'event_timestamp' column exists for pandas backend compatibility + if ( + "event_timestamp" not in df.columns + and timestamp_field_mapped != "event_timestamp" + ): + if timestamp_field_mapped in df.columns: + df["event_timestamp"] = df[timestamp_field_mapped] + + return df + + except Exception as e: + raise RuntimeError(f"Failed to load data from {source_path}: {e}") + @staticmethod def pull_latest_from_table_or_query( config: RepoConfig, @@ -1425,28 +1774,22 @@ def pull_latest_from_table_or_query( source_path = store._get_source_path(data_source, config) def _load(): - try: - # Get field mapping for column renaming after loading - field_mapping = getattr(data_source, "field_mapping", None) - - # The timestamp_field parameter is already the original field name - # (reverse mapping is handled by _get_column_names) - filter_timestamp_field = timestamp_field - - # Load and filter the dataset using the original timestamp field name - ds = RayOfflineStore._create_filtered_dataset( - source_path, filter_timestamp_field, start_date, end_date - ) - - # Convert to pandas for deduplication and column selection - df = ds.to_pandas() - df = make_df_tzaware(df) - - # Apply field mapping if needed - if field_mapping: - df = df.rename(columns=field_mapping) + # Load and filter the dataset using the shared method + df = store._load_and_filter_dataset( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) - # Now use the mapped timestamp field name for all operations + # Handle deduplication (keep latest records) - specific to pull_latest + if join_key_columns and not df.empty: + # Get field mapping for proper column names + field_mapping = getattr(data_source, "field_mapping", None) timestamp_field_mapped = ( field_mapping.get(timestamp_field, timestamp_field) if field_mapping @@ -1460,106 +1803,27 @@ def _load(): else created_timestamp_column ) - # Handle empty DataFrame case - if df.empty: - # Create an empty DataFrame with the required columns - empty_columns = ( - join_key_columns - + feature_name_columns - + [timestamp_field_mapped] - ) - if created_timestamp_column_mapped: - empty_columns.append(created_timestamp_column_mapped) - if not join_key_columns: - empty_columns.append(DUMMY_ENTITY_ID) - - # Add event_timestamp column for pandas backend compatibility - if "event_timestamp" not in empty_columns: - empty_columns.append("event_timestamp") - - # Create empty DataFrame with proper column types - empty_df = pd.DataFrame(columns=empty_columns) - return empty_df - - # Ensure timestamp is properly formatted - if timestamp_field_mapped in df.columns: - df[timestamp_field_mapped] = ( - pd.to_datetime( - df[timestamp_field_mapped], utc=True, errors="coerce" - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - - if ( - created_timestamp_column_mapped - and created_timestamp_column_mapped in df.columns - ): - df[created_timestamp_column_mapped] = ( - pd.to_datetime( - df[created_timestamp_column_mapped], - utc=True, - errors="coerce", - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - - # Prepare columns to select + # Build timestamp columns for sorting timestamp_columns = [timestamp_field_mapped] if created_timestamp_column_mapped: timestamp_columns.append(created_timestamp_column_mapped) - all_required_columns = ( - join_key_columns + feature_name_columns + timestamp_columns - ) - - # Select only the required columns that exist - available_columns = [ - col for col in all_required_columns if col in df.columns + # Filter out timestamp columns that don't exist in the dataframe + existing_timestamp_columns = [ + col for col in timestamp_columns if col in df.columns ] - df = df[available_columns] - - # Handle deduplication (keep latest records) - if join_key_columns: - # Sort by timestamp columns (latest first) and deduplicate by join keys - # Filter out timestamp columns that don't exist in the dataframe - existing_timestamp_columns = [ - col for col in timestamp_columns if col in df.columns - ] - sort_columns = join_key_columns + existing_timestamp_columns - if sort_columns: - df = df.sort_values( - sort_columns, - ascending=[True] * len(join_key_columns) - + [False] * len(existing_timestamp_columns), - ) - df = df.drop_duplicates(subset=join_key_columns, keep="first") - else: - # No join keys - add dummy entity and sort by timestamp - df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL - # Filter out timestamp columns that don't exist in the dataframe - existing_timestamp_columns = [ - col for col in timestamp_columns if col in df.columns - ] - if existing_timestamp_columns: - df = df.sort_values(existing_timestamp_columns, ascending=False) - - # Reset index - df = df.reset_index(drop=True) - - # Ensure 'event_timestamp' column exists for pandas backend compatibility - if ( - "event_timestamp" not in df.columns - and timestamp_field_mapped != "event_timestamp" - ): - if timestamp_field_mapped in df.columns: - df["event_timestamp"] = df[timestamp_field_mapped] - return df + # Sort by join keys (ascending) and timestamps (descending for latest first) + sort_columns = join_key_columns + existing_timestamp_columns + if sort_columns: + df = df.sort_values( + sort_columns, + ascending=[True] * len(join_key_columns) + + [False] * len(existing_timestamp_columns), + ) + df = df.drop_duplicates(subset=join_key_columns, keep="first") - except Exception as e: - raise RuntimeError(f"Failed to load data from {source_path}: {e}") + return df return RayRetrievalJob( _load, staging_location=config.offline_store.storage_path @@ -1586,129 +1850,16 @@ def pull_all_from_table_or_query( raise FileNotFoundError(f"Parquet path does not exist: {source_path}") def _load(): - try: - # Get field mapping for column renaming after loading - field_mapping = getattr(data_source, "field_mapping", None) - - # The timestamp_field parameter is already the original field name - # (reverse mapping is handled by _get_column_names) - filter_timestamp_field = timestamp_field - - # Load and filter the dataset using the original timestamp field name - ds = RayOfflineStore._create_filtered_dataset( - source_path, filter_timestamp_field, start_date, end_date - ) - - # Convert to pandas for column selection - df = ds.to_pandas() - df = make_df_tzaware(df) - - # Apply field mapping if needed - if field_mapping: - df = df.rename(columns=field_mapping) - - # Now use the mapped timestamp field name for all operations - timestamp_field_mapped = ( - field_mapping.get(timestamp_field, timestamp_field) - if field_mapping - else timestamp_field - ) - created_timestamp_column_mapped = ( - field_mapping.get( - created_timestamp_column, created_timestamp_column - ) - if field_mapping and created_timestamp_column - else created_timestamp_column - ) - - # Handle empty DataFrame case - if df.empty: - # Create an empty DataFrame with the required columns - empty_columns = ( - join_key_columns - + feature_name_columns - + [timestamp_field_mapped] - ) - if created_timestamp_column_mapped: - empty_columns.append(created_timestamp_column_mapped) - if not join_key_columns: - empty_columns.append(DUMMY_ENTITY_ID) - - # Add event_timestamp column for pandas backend compatibility - if "event_timestamp" not in empty_columns: - empty_columns.append("event_timestamp") - - # Create empty DataFrame with proper column types - empty_df = pd.DataFrame(columns=empty_columns) - return empty_df - - # Ensure timestamp is properly formatted - if timestamp_field_mapped in df.columns: - df[timestamp_field_mapped] = ( - pd.to_datetime( - df[timestamp_field_mapped], utc=True, errors="coerce" - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - - if ( - created_timestamp_column_mapped - and created_timestamp_column_mapped in df.columns - ): - df[created_timestamp_column_mapped] = ( - pd.to_datetime( - df[created_timestamp_column_mapped], - utc=True, - errors="coerce", - ) - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - - # Prepare columns to select - timestamp_columns = [timestamp_field_mapped] - if created_timestamp_column_mapped: - timestamp_columns.append(created_timestamp_column_mapped) - - all_required_columns = ( - join_key_columns + feature_name_columns + timestamp_columns - ) - - # Add dummy entity if no join keys - if not join_key_columns: - df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL - all_required_columns.append(DUMMY_ENTITY_ID) - - # Select only the required columns that exist - available_columns = [ - col for col in all_required_columns if col in df.columns - ] - df = df[available_columns] - - # Sort by timestamp (most recent first) - # Filter out timestamp columns that don't exist in the dataframe - existing_timestamp_columns = [ - col for col in timestamp_columns if col in df.columns - ] - if existing_timestamp_columns: - df = df.sort_values(existing_timestamp_columns, ascending=False) - - # Reset index - df = df.reset_index(drop=True) - - # Ensure 'event_timestamp' column exists for pandas backend compatibility - if ( - "event_timestamp" not in df.columns - and timestamp_field_mapped != "event_timestamp" - ): - if timestamp_field_mapped in df.columns: - df["event_timestamp"] = df[timestamp_field_mapped] - - return df - - except Exception as e: - raise RuntimeError(f"Failed to load data from {source_path}: {e}") + return store._load_and_filter_dataset( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) return RayRetrievalJob( _load, staging_location=config.offline_store.storage_path @@ -1762,7 +1913,20 @@ def offline_write_batch( assert isinstance(ray_config, RayOfflineStoreConfig) assert isinstance(feature_view.batch_source, FileSource) - # Use the existing batch source path (like other offline stores) + # Enhanced schema validation using safe utility + validation_result = _safe_validate_schema( + config, feature_view.batch_source, table.column_names, "offline_write_batch" + ) + + if validation_result: + expected_schema, expected_columns = validation_result + # Try to reorder columns to match expected order if needed + if expected_columns != table.column_names and set(expected_columns) == set( + table.column_names + ): + logger.info("Reordering table columns to match expected schema") + table = table.select(expected_columns) + batch_source_path = feature_view.batch_source.file_options.uri feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) @@ -1802,7 +1966,6 @@ def create_saved_dataset_destination( """Create a saved dataset destination for Ray offline store.""" if path is None: - # Use default path based on config ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) base_storage_path = ray_config.storage_path or "/tmp/ray-storage" From b31a8fd96e7defeaf1e7dd512c166c3969d24994 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Thu, 10 Jul 2025 15:13:00 +0530 Subject: [PATCH 06/10] feat: Use Ray Dataset for data processing Signed-off-by: ntkathole --- .../contrib/ray_offline_store/ray.py | 1879 +++++++++++------ 1 file changed, 1238 insertions(+), 641 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index de3ed55ce7f..b41d6e45a8b 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -38,165 +38,218 @@ from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage, ValidationReference -from feast.type_map import feast_value_type_to_pandas_type +from feast.type_map import feast_value_type_to_pandas_type, pa_to_feast_value_type from feast.utils import _get_column_names, make_df_tzaware from feast.value_type import ValueType logger = logging.getLogger(__name__) -def _normalize_timestamp_column( - df: pd.DataFrame, column: str, inplace: bool = False -) -> pd.DataFrame: +def _get_data_schema_info( + data: Union[pd.DataFrame, Dataset], +) -> Tuple[Dict[str, Any], List[str]]: """ - Normalize a timestamp column to UTC with second precision. + Extract schema information from DataFrame or Dataset. Args: - df: DataFrame containing the timestamp column - column: Name of the timestamp column to normalize - inplace: Whether to modify the DataFrame in place + data: DataFrame or Ray Dataset Returns: - DataFrame with normalized timestamp column + Tuple of (dtypes_dict, column_names) """ - if not inplace: - df = df.copy() - - if column in df.columns: - df[column] = ( - pd.to_datetime(df[column], utc=True, errors="coerce") - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - - return df + if isinstance(data, Dataset): + schema = data.schema() + # Embed _create_dtypes_dict_from_schema logic inline + dtypes = {} + for i, col in enumerate(schema.names): + field_type = schema.field(i).type + # Embed _pa_type_to_pandas_dtype logic inline + try: + pa_type_str = str(field_type).lower() + feast_value_type = pa_to_feast_value_type(pa_type_str) + pandas_type_str = feast_value_type_to_pandas_type(feast_value_type) + dtypes[col] = pd.api.types.pandas_dtype(pandas_type_str) + except Exception: + dtypes[col] = pd.api.types.pandas_dtype("object") + columns = schema.names + else: + dtypes = data.dtypes.to_dict() + columns = list(data.columns) + return dtypes, columns -def _normalize_timestamp_columns( - df: pd.DataFrame, columns: List[str], inplace: bool = False -) -> pd.DataFrame: +def _apply_to_data( + data: Union[pd.DataFrame, Dataset], + process_func: Callable[[pd.DataFrame], pd.DataFrame], + inplace: bool = False, +) -> Union[pd.DataFrame, Dataset]: """ - Normalize multiple timestamp columns to UTC with second precision. + Apply a processing function to DataFrame or Dataset. Args: - df: DataFrame containing the timestamp columns - columns: List of timestamp column names to normalize - inplace: Whether to modify the DataFrame in place + data: DataFrame or Ray Dataset to process + process_func: Function that takes a DataFrame and returns a processed DataFrame + inplace: Whether to modify DataFrame in place (only applies to pandas) Returns: - DataFrame with normalized timestamp columns + Processed DataFrame or Dataset """ - if not inplace: - df = df.copy() - - for column in columns: - if column in df.columns: - df = _normalize_timestamp_column(df, column, inplace=True) - - return df + if isinstance(data, Dataset): + return data.map_batches(process_func, batch_format="pandas") + else: + if not inplace: + data = data.copy() + return process_func(data) -def _create_time_window_column( - df: pd.DataFrame, - timestamp_column: str, - window_size: str, - window_column: str = "time_window", +def _normalize_timestamp_columns( + data: Union[pd.DataFrame, Dataset], + columns: Union[str, List[str]], inplace: bool = False, -) -> pd.DataFrame: +) -> Union[pd.DataFrame, Dataset]: """ - Create a time window column by flooring timestamps to specified window size. + Normalize timestamp columns to UTC with second precision. + Works with both pandas DataFrames and Ray Datasets. Args: - df: DataFrame containing the timestamp column - timestamp_column: Name of the timestamp column - window_size: Window size string (e.g., "1H", "30min") - window_column: Name for the new window column - inplace: Whether to modify the DataFrame in place + data: DataFrame or Ray Dataset containing the timestamp columns + columns: Column name (str) or list of column names (List[str]) to normalize + inplace: Whether to modify the DataFrame in place (only applies to pandas) Returns: - DataFrame with added time window column + DataFrame or Dataset with normalized timestamp columns """ - if not inplace: - df = df.copy() + # Normalize input to always be a list + column_list = [columns] if isinstance(columns, str) else columns - if timestamp_column in df.columns: - df[window_column] = ( - pd.to_datetime(df[timestamp_column]) - .dt.floor(window_size) + def apply_normalization(series: pd.Series) -> pd.Series: + return ( + pd.to_datetime(series, utc=True, errors="coerce") + .dt.floor("s") .astype("datetime64[ns, UTC]") ) - return df + if isinstance(data, Dataset): + def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame: + for column in column_list: + if not batch.empty and column in batch.columns: + batch[column] = apply_normalization(batch[column]) + return batch -def _create_empty_timestamp_column( - length: int, dtype: str = "datetime64[ns, UTC]" -) -> pd.Series: - """ - Create an empty timestamp column with proper dtype. - Args: - length: Length of the series - dtype: Pandas dtype for the timestamp column - Returns: - Series with NaT values and proper datetime dtype - """ - return pd.Series([pd.NaT] * length, dtype=dtype) + return data.map_batches(normalize_batch, batch_format="pandas") + else: + if not inplace: + data = data.copy() + + for column in column_list: + if column in data.columns: + data[column] = apply_normalization(data[column]) + return data def _ensure_timestamp_compatibility( - df: pd.DataFrame, timestamp_fields: List[str], inplace: bool = False -) -> pd.DataFrame: + data: Union[pd.DataFrame, Dataset], + timestamp_fields: List[str], + inplace: bool = False, +) -> Union[pd.DataFrame, Dataset]: """ Ensure timestamp columns have compatible dtypes and precision for joins. + Works with both pandas DataFrames and Ray Datasets. Args: - df: DataFrame to process + data: DataFrame or Ray Dataset to process timestamp_fields: List of timestamp field names - inplace: Whether to modify the DataFrame in place + inplace: Whether to modify the DataFrame in place (only applies to pandas) Returns: - DataFrame with compatible timestamp columns + DataFrame or Dataset with compatible timestamp columns """ - if not inplace: - df = df.copy() + if isinstance(data, Dataset): + # Ray Dataset path + def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame: + # Use existing utility for timezone awareness + batch = make_df_tzaware(batch) + + # Then normalize timestamp precision for specified fields only + for field in timestamp_fields: + if field in batch.columns: + batch[field] = ( + pd.to_datetime(batch[field], utc=True, errors="coerce") + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + return batch - # Use existing utility for timezone awareness - df = make_df_tzaware(df) + return data.map_batches(ensure_compatibility, batch_format="pandas") + else: + # Pandas DataFrame path + if not inplace: + data = data.copy() - # Then normalize timestamp precision for specified fields only - for field in timestamp_fields: - if field in df.columns: - df = _normalize_timestamp_column(df, field, inplace=True) + # Use existing utility for timezone awareness + data = make_df_tzaware(data) - return df + # Then normalize timestamp precision for specified fields only + for field in timestamp_fields: + if field in data.columns: + data = _normalize_timestamp_columns(data, field, inplace=True) + return data -def _create_empty_dataframe_with_timestamp_columns( - columns: List[str], timestamp_columns: List[str] +def _build_required_columns( + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_columns: List[str], +) -> List[str]: + """ + Build list of required columns for data processing. + Args: + join_key_columns: List of join key columns + feature_name_columns: List of feature columns + timestamp_columns: List of timestamp columns + Returns: + List of all required columns + """ + all_required_columns = join_key_columns + feature_name_columns + timestamp_columns + if not join_key_columns: + all_required_columns.append(DUMMY_ENTITY_ID) + if "event_timestamp" not in all_required_columns: + all_required_columns.append("event_timestamp") + return all_required_columns + + +def _handle_empty_dataframe_case( + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_columns: List[str], ) -> pd.DataFrame: """ - Create an empty DataFrame with proper column types including datetime columns. + Handle empty DataFrame case by creating properly structured empty DataFrame. Args: - columns: List of all column names - timestamp_columns: List of timestamp column names that need proper dtype + join_key_columns: List of join key columns + feature_name_columns: List of feature columns + timestamp_columns: List of timestamp columns Returns: - Empty DataFrame with proper column types + Empty DataFrame with proper structure and column types """ - df = pd.DataFrame(columns=columns) - - # Set proper dtype for timestamp columns + empty_columns = _build_required_columns( + join_key_columns, feature_name_columns, timestamp_columns + ) + df = pd.DataFrame(columns=empty_columns) for col in timestamp_columns: if col in df.columns: df[col] = df[col].astype("datetime64[ns, UTC]") - return df def _safe_infer_event_timestamp_column( - entity_df: pd.DataFrame, fallback_column: str = "event_timestamp" + data: Union[pd.DataFrame, Dataset], fallback_column: str = "event_timestamp" ) -> str: """ - Safely infer the event timestamp column using offline_utils with fallback. + Safely infer the event timestamp column. + Works with both pandas DataFrames and Ray Datasets. Args: - entity_df: Entity DataFrame to analyze + data: DataFrame or Ray Dataset to analyze fallback_column: Default column name to use if inference fails Returns: Inferred or fallback timestamp column name """ try: - return infer_event_timestamp_from_entity_df(entity_df.dtypes.to_dict()) + dtypes, _ = _get_data_schema_info(data) + return infer_event_timestamp_from_entity_df(dtypes) except Exception as e: logger.debug( f"Timestamp column inference failed: {e}, using fallback: {fallback_column}" @@ -205,51 +258,92 @@ def _safe_infer_event_timestamp_column( def _safe_get_entity_timestamp_bounds( - entity_df: pd.DataFrame, timestamp_column: str + data: Union[pd.DataFrame, Dataset], timestamp_column: str ) -> Tuple[Optional[datetime], Optional[datetime]]: """ - Safely get entity timestamp bounds using offline_utils with fallback. + Safely get entity timestamp bounds. + Works with both pandas DataFrames and Ray Datasets. Args: - entity_df: Entity DataFrame + data: DataFrame or Ray Dataset timestamp_column: Name of timestamp column Returns: Tuple of (min_timestamp, max_timestamp) or (None, None) if failed """ try: - if timestamp_column in entity_df.columns: - min_ts, max_ts = get_entity_df_timestamp_bounds(entity_df, timestamp_column) - # Convert Pandas Timestamp to datetime if needed - if hasattr(min_ts, "to_pydatetime"): - min_ts = min_ts.to_pydatetime() - if hasattr(max_ts, "to_pydatetime"): - max_ts = max_ts.to_pydatetime() - return min_ts, max_ts + if isinstance(data, Dataset): + # Ray Dataset path - try Ray's built-in operations first + min_ts = data.min(timestamp_column) + max_ts = data.max(timestamp_column) + else: + # Pandas DataFrame path + if timestamp_column in data.columns: + min_ts, max_ts = get_entity_df_timestamp_bounds(data, timestamp_column) + else: + return None, None + + # Convert to datetime if needed + if hasattr(min_ts, "to_pydatetime"): + min_ts = min_ts.to_pydatetime() + elif isinstance(min_ts, pd.Timestamp): + min_ts = min_ts.to_pydatetime() + + if hasattr(max_ts, "to_pydatetime"): + max_ts = max_ts.to_pydatetime() + elif isinstance(max_ts, pd.Timestamp): + max_ts = max_ts.to_pydatetime() + + return min_ts, max_ts except Exception as e: logger.debug( f"Timestamp bounds extraction failed: {e}, falling back to manual calculation" ) - # Fallback to original logic - try: - if timestamp_column in entity_df.columns: - timestamps = pd.to_datetime(entity_df[timestamp_column], utc=True) - return timestamps.min().to_pydatetime(), timestamps.max().to_pydatetime() - except Exception: - pass + # Fallback to manual calculation + try: + if isinstance(data, Dataset): + # Ray Dataset fallback + def extract_bounds(batch: pd.DataFrame) -> pd.DataFrame: + if timestamp_column in batch.columns and not batch.empty: + timestamps = pd.to_datetime(batch[timestamp_column], utc=True) + return pd.DataFrame( + {"min_ts": [timestamps.min()], "max_ts": [timestamps.max()]} + ) + return pd.DataFrame({"min_ts": [None], "max_ts": [None]}) + + bounds_ds = data.map_batches(extract_bounds, batch_format="pandas") + bounds_df = bounds_ds.to_pandas() + + if not bounds_df.empty: + min_ts = bounds_df["min_ts"].min() + max_ts = bounds_df["max_ts"].max() + + if pd.notna(min_ts) and pd.notna(max_ts): + return min_ts.to_pydatetime(), max_ts.to_pydatetime() + else: + # Pandas DataFrame fallback + if timestamp_column in data.columns: + timestamps = pd.to_datetime(data[timestamp_column], utc=True) + return ( + timestamps.min().to_pydatetime(), + timestamps.max().to_pydatetime(), + ) + except Exception: + pass - return None, None + return None, None def _safe_validate_entity_dataframe( - entity_df: pd.DataFrame, + data: Union[pd.DataFrame, Dataset], feature_views: List[FeatureView], project: str, registry: BaseRegistry, ) -> None: """ - Safely validate entity DataFrame using offline_utils with graceful fallback. + Safely validate entity DataFrame or Dataset. + Works with both pandas DataFrames and Ray Datasets. Args: - entity_df: Entity DataFrame to validate + data: DataFrame or Ray Dataset to validate feature_views: List of feature views to validate against project: Feast project name registry: Feature registry @@ -258,23 +352,26 @@ def _safe_validate_entity_dataframe( # Get expected join keys for validation expected_join_keys = get_expected_join_keys(project, feature_views, registry) + dtypes, columns = _get_data_schema_info(data) + # Infer event timestamp column - timestamp_col = infer_event_timestamp_from_entity_df(entity_df.dtypes.to_dict()) + timestamp_col = infer_event_timestamp_from_entity_df(dtypes) - # Validate entity DataFrame has required columns - assert_expected_columns_in_entity_df( - entity_df.dtypes.to_dict(), expected_join_keys, timestamp_col - ) + # Validate DataFrame/Dataset has required columns + assert_expected_columns_in_entity_df(dtypes, expected_join_keys, timestamp_col) + data_type = "Dataset" if isinstance(data, Dataset) else "DataFrame" logger.info( - f"Entity DataFrame validation passed:\n" + f"Entity {data_type} validation passed:\n" f" Expected join keys: {expected_join_keys}\n" - f" Detected timestamp column: {timestamp_col}" + f" Detected timestamp column: {timestamp_col}\n" + f" Available columns: {columns}" ) except Exception as e: # Log validation issues but don't fail - logger.warning(f"Entity DataFrame validation skipped due to error: {e}") + data_type = "Dataset" if isinstance(data, Dataset) else "DataFrame" + logger.warning(f"Entity {data_type} validation skipped due to error: {e}") logger.debug("Validation error details:", exc_info=True) @@ -324,50 +421,55 @@ def _safe_validate_schema( def _convert_feature_column_types( - batch: pd.DataFrame, feature_views: List[FeatureView] -) -> pd.DataFrame: + data: Union[pd.DataFrame, Dataset], feature_views: List[FeatureView] +) -> Union[pd.DataFrame, Dataset]: """ Convert feature columns to appropriate pandas types using Feast's type mapping utilities. + Works with both pandas DataFrames and Ray Datasets. Args: - batch: DataFrame containing feature data + data: DataFrame or Ray Dataset containing feature data feature_views: List of feature views with type information Returns: - DataFrame with properly converted feature column types + DataFrame or Dataset with properly converted feature column types """ - batch = batch.copy() - for fv in feature_views: - for feature in fv.features: - feat_name = feature.name + def convert_batch(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() - # Check if this feature exists in the batch - if feat_name not in batch.columns: - continue + for fv in feature_views: + for feature in fv.features: + feat_name = feature.name - try: - # Get the Feast ValueType for this feature - value_type = feature.dtype.to_value_type() + # Check if this feature exists in the batch + if feat_name not in batch.columns: + continue - # Handle array/list types - if value_type.name.endswith("_LIST"): - batch[feat_name] = _convert_array_column( - batch[feat_name], value_type - ) - else: - # Handle scalar types using feast type mapping - target_pandas_type = feast_value_type_to_pandas_type(value_type) - batch[feat_name] = _convert_scalar_column( - batch[feat_name], value_type, target_pandas_type + try: + # Get the Feast ValueType for this feature + value_type = feature.dtype.to_value_type() + + # Handle array/list types + if value_type.name.endswith("_LIST"): + batch[feat_name] = _convert_array_column( + batch[feat_name], value_type + ) + else: + # Handle scalar types using feast type mapping + target_pandas_type = feast_value_type_to_pandas_type(value_type) + batch[feat_name] = _convert_scalar_column( + batch[feat_name], value_type, target_pandas_type + ) + + except Exception as e: + logger.warning( + f"Failed to convert feature {feat_name} to proper type: {e}" ) + # Keep original dtype if conversion fails + continue - except Exception as e: - logger.warning( - f"Failed to convert feature {feat_name} to proper type: {e}" - ) - # Keep original dtype if conversion fails - continue + return batch - return batch + return _apply_to_data(data, convert_batch) def _convert_scalar_column( @@ -421,6 +523,25 @@ def convert_array_item(item): return series.apply(convert_array_item) +def _apply_field_mapping( + data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str] +) -> Union[pd.DataFrame, Dataset]: + """ + Apply field mapping to column names. + Works with both pandas DataFrames and Ray Datasets. + Args: + data: DataFrame or Ray Dataset to apply mapping to + field_mapping: Dictionary mapping old column names to new column names + Returns: + DataFrame or Dataset with renamed columns + """ + + def rename_columns(df: pd.DataFrame) -> pd.DataFrame: + return df.rename(columns=field_mapping) + + return _apply_to_data(data, rename_columns) + + class RayOfflineStoreConfig(FeastConfigBaseModel): """ Configuration for the Ray Offline Store. @@ -595,7 +716,9 @@ def _manual_point_in_time_join( if feat in features_df.columns and pd.api.types.is_datetime64_any_dtype( features_df[feat] ): - result[feat] = _create_empty_timestamp_column(len(result)) + result[feat] = pd.Series( + [pd.NaT] * len(result), dtype="datetime64[ns, UTC]" + ) else: result[feat] = np.nan @@ -622,14 +745,13 @@ def _manual_point_in_time_join( matching_features = features_df[entity_matches] - if matching_features.empty: - continue - + # Apply time filter if timestamp field exists entity_timestamp = entity_row[timestamp_field] if timestamp_field in matching_features.columns: time_matches = matching_features[timestamp_field] <= entity_timestamp matching_features = matching_features[time_matches] + # Skip if no features match entity criteria or time criteria if matching_features.empty: continue @@ -720,8 +842,8 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: features_filtered = features[available_feature_cols].copy() # Ensure timestamp columns have compatible dtypes and precision - batch = _normalize_timestamp_column(batch, timestamp_field, inplace=True) - features_filtered = _normalize_timestamp_column( + batch = _normalize_timestamp_columns(batch, timestamp_field, inplace=True) + features_filtered = _normalize_timestamp_columns( features_filtered, timestamp_field, inplace=True ) @@ -958,9 +1080,13 @@ def _add_time_windows_and_source_marker( """Add time windows and source markers to dataset.""" def add_window_and_source(batch: pd.DataFrame) -> pd.DataFrame: - batch = _create_time_window_column( - batch, timestamp_field, window_size, "time_window" - ) + batch = batch.copy() + if timestamp_field in batch.columns: + batch["time_window"] = ( + pd.to_datetime(batch[timestamp_field]) + .dt.floor(window_size) + .astype("datetime64[ns, UTC]") + ) batch["_data_source"] = source_marker return batch @@ -1056,6 +1182,7 @@ def __init__( self._on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None self._feature_refs: List[str] = [] self._entity_df: Optional[pd.DataFrame] = None + self._prefer_ray_datasets: bool = True # New flag to prefer Ray datasets def _create_metadata(self) -> RetrievalMetadata: """Create metadata from the entity DataFrame and feature references.""" @@ -1071,9 +1198,26 @@ def _create_metadata(self) -> RetrievalMetadata: # Get keys (all columns except the detected timestamp column) keys = [col for col in self._entity_df.columns if col != timestamp_col] else: - min_timestamp = None - max_timestamp = None - keys = [] + # Try to extract metadata from Ray dataset if entity_df is not available + try: + result = self._resolve() + if isinstance(result, Dataset): + timestamp_col = _safe_infer_event_timestamp_column( + result, "event_timestamp" + ) + min_timestamp, max_timestamp = _safe_get_entity_timestamp_bounds( + result, timestamp_col + ) + schema = result.schema() + keys = [col for col in schema.names if col != timestamp_col] + else: + min_timestamp = None + max_timestamp = None + keys = [] + except Exception: + min_timestamp = None + max_timestamp = None + keys = [] return RetrievalMetadata( features=self._feature_refs, @@ -1094,12 +1238,27 @@ def _resolve(self) -> Union[Dataset, pd.DataFrame]: result = self._dataset_or_callable return result + def _get_ray_dataset(self) -> Dataset: + """Get the result as a Ray Dataset, converting if necessary.""" + if self._cached_dataset is not None: + return self._cached_dataset + + result = self._resolve() + if isinstance(result, Dataset): + self._cached_dataset = result + return result + elif isinstance(result, pd.DataFrame): + self._cached_dataset = ray.data.from_pandas(result) + return self._cached_dataset + else: + raise ValueError(f"Unsupported result type: {type(result)}") + def to_df( self, validation_reference: Optional[ValidationReference] = None, timeout: Optional[int] = None, ) -> pd.DataFrame: - # Use cached DataFrame if available for repeated access + # Use cached DataFrame if available and no ODFVs if self._cached_df is not None and not self.on_demand_feature_views: df = self._cached_df else: @@ -1113,11 +1272,16 @@ def to_df( validation_reference=validation_reference, timeout=timeout ) else: - result = self._resolve() - if isinstance(result, pd.DataFrame): - df = result + # For Ray datasets, prefer keeping data distributed until the final conversion + if self._prefer_ray_datasets: + ray_ds = self._get_ray_dataset() + df = ray_ds.to_pandas() else: - df = result.to_pandas() + result = self._resolve() + if isinstance(result, pd.DataFrame): + df = result + else: + df = result.to_pandas() self._cached_df = df # Handle validation reference if provided @@ -1148,30 +1312,42 @@ def to_arrow( validation_reference=validation_reference, timeout=timeout ) - # For non-ODFV cases, use direct conversion - result = self._resolve() - if isinstance(result, pd.DataFrame): - return pa.Table.from_pandas(result) - - # For Ray Dataset, use direct Arrow conversion if available - try: - if hasattr(result, "to_arrow"): - return result.to_arrow() - else: + # For Ray datasets, use direct Arrow conversion when available + if self._prefer_ray_datasets: + try: + ray_ds = self._get_ray_dataset() + # Try to use Ray's native to_arrow() if available + if hasattr(ray_ds, "to_arrow"): + return ray_ds.to_arrow() + else: + # Fallback to pandas conversion + df = ray_ds.to_pandas() + return pa.Table.from_pandas(df) + except Exception: # Fallback to pandas conversion - return pa.Table.from_pandas(result.to_pandas()) - except Exception: - # Fallback to pandas conversion - return pa.Table.from_pandas(result.to_pandas()) + df = self.to_df( + validation_reference=validation_reference, timeout=timeout + ) + return pa.Table.from_pandas(df) + else: + # Original implementation for non-Ray datasets + result = self._resolve() + if isinstance(result, pd.DataFrame): + return pa.Table.from_pandas(result) + else: + # For Ray Dataset, convert to pandas first then to arrow + df = result.to_pandas() + return pa.Table.from_pandas(df) def to_remote_storage(self) -> list[str]: if not self._staging_location: raise ValueError("Staging location must be set for remote materialization.") try: - ds = self._resolve() + # Use Ray dataset directly for remote storage + ray_ds = self._get_ray_dataset() RayOfflineStore._ensure_ray_initialized() output_uri = os.path.join(self._staging_location, str(uuid.uuid4())) - ds.write_parquet(output_uri) + ray_ds.write_parquet(output_uri) return [output_uri] except Exception as e: raise RuntimeError(f"Failed to write to remote storage: {e}") @@ -1195,16 +1371,34 @@ def to_sql(self) -> str: raise NotImplementedError("SQL export not supported for Ray offline store") def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: - return self._resolve().to_pandas() + # Use Ray dataset when possible + if self._prefer_ray_datasets: + ray_ds = self._get_ray_dataset() + return ray_ds.to_pandas() + else: + return self._resolve().to_pandas() def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: - result = self._resolve() - if isinstance(result, pd.DataFrame): - return pa.Table.from_pandas(result) - - # For Ray Dataset, convert to pandas first then to arrow - df = result.to_pandas() - return pa.Table.from_pandas(df) + # Use Ray dataset when possible + if self._prefer_ray_datasets: + ray_ds = self._get_ray_dataset() + try: + if hasattr(ray_ds, "to_arrow"): + return ray_ds.to_arrow() + else: + df = ray_ds.to_pandas() + return pa.Table.from_pandas(df) + except Exception: + df = ray_ds.to_pandas() + return pa.Table.from_pandas(df) + else: + result = self._resolve() + if isinstance(result, pd.DataFrame): + return pa.Table.from_pandas(result) + else: + # For Ray Dataset, convert to pandas first then to arrow + df = result.to_pandas() + return pa.Table.from_pandas(df) def persist( self, @@ -1212,7 +1406,7 @@ def persist( allow_overwrite: Optional[bool] = False, timeout: Optional[int] = None, ) -> str: - """Persist the dataset to storage.""" + """Persist the dataset to storage using Ray operations.""" if not isinstance(storage, SavedDatasetFileStorage): raise ValueError( @@ -1223,24 +1417,60 @@ def persist( if not allow_overwrite and os.path.exists(destination_path): raise SavedDatasetLocationAlreadyExists(location=destination_path) try: - result = self._resolve() + # Use Ray dataset directly for persistence + ray_ds = self._get_ray_dataset() + if not destination_path.startswith(("s3://", "gs://", "hdfs://")): os.makedirs(os.path.dirname(destination_path), exist_ok=True) - # Handle both DataFrame and Ray Dataset - if isinstance(result, pd.DataFrame): - # For DataFrame, convert to Ray Dataset first - RayOfflineStore._ensure_ray_initialized() - ds = ray.data.from_pandas(result) - ds.write_parquet(destination_path) - else: - # For Ray Dataset, use direct write - result.write_parquet(destination_path) + # Use Ray's native write operations + ray_ds.write_parquet(destination_path) return destination_path except Exception as e: raise RuntimeError(f"Failed to persist dataset to {destination_path}: {e}") + def materialize(self) -> None: + """Materialize the Ray dataset to improve subsequent access performance.""" + try: + ray_ds = self._get_ray_dataset() + materialized_ds = ray_ds.materialize() + self._cached_dataset = materialized_ds + logger.info("Ray dataset materialized successfully") + except Exception as e: + logger.warning(f"Failed to materialize Ray dataset: {e}") + + def count(self) -> int: + """Get the number of rows in the dataset efficiently using Ray operations.""" + try: + ray_ds = self._get_ray_dataset() + return ray_ds.count() + except Exception: + # Fallback to pandas + df = self.to_df() + return len(df) + + def take(self, limit: int) -> pd.DataFrame: + """Take a limited number of rows efficiently using Ray operations.""" + try: + ray_ds = self._get_ray_dataset() + limited_ds = ray_ds.limit(limit) + return limited_ds.to_pandas() + except Exception: + # Fallback to pandas + df = self.to_df() + return df.head(limit) + + def schema(self) -> pa.Schema: + """Get the schema of the dataset efficiently using Ray operations.""" + try: + ray_ds = self._get_ray_dataset() + return ray_ds.schema() + except Exception: + # Fallback to pandas + df = self.to_df() + return pa.Table.from_pandas(df).schema + class RayOfflineStore(OfflineStore): def __init__(self): @@ -1248,6 +1478,7 @@ def __init__(self): self._ray_initialized: bool = False self._resource_manager: Optional[RayResourceManager] = None self._data_processor: Optional[RayDataProcessor] = None + self._performance_monitoring: bool = True # Enable performance monitoring @staticmethod def _ensure_ray_initialized(config: Optional[RepoConfig] = None): @@ -1278,16 +1509,35 @@ def _ensure_ray_initialized(config: Optional[RepoConfig] = None): ctx.shuffle_strategy = "sort" # type: ignore ctx.enable_tensor_extension_casting = False + # Log Ray cluster information + if ray.is_initialized(): + cluster_resources = ray.cluster_resources() + logger.info( + f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " + f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" + ) + def _init_ray(self, config: RepoConfig): ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) - self._ensure_ray_initialized(config) + RayOfflineStore._ensure_ray_initialized(config) if self._resource_manager is None: self._resource_manager = RayResourceManager(ray_config) self._resource_manager.configure_ray_context() if self._data_processor is None: self._data_processor = RayDataProcessor(self._resource_manager) + def _log_performance_metrics( + self, operation: str, dataset_size: int, duration: float + ): + """Log performance metrics for Ray operations.""" + if self._performance_monitoring: + throughput = dataset_size / duration if duration > 0 else 0 + logger.info( + f"Ray {operation} performance: {dataset_size} rows in {duration:.2f}s " + f"({throughput:.0f} rows/s)" + ) + def _get_source_path(self, source: DataSource, config: RepoConfig) -> str: if not isinstance(source, FileSource): raise ValueError("RayOfflineStore currently only supports FileSource") @@ -1295,334 +1545,241 @@ def _get_source_path(self, source: DataSource, config: RepoConfig) -> str: uri = FileSource.get_uri_for_file_path(repo_path, source.path) return uri - @staticmethod - def _create_filtered_dataset( - source_path: str, - timestamp_field: str, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - ) -> Dataset: - """Helper method to create a filtered dataset based on timestamp range.""" - ds = ray.data.read_parquet(source_path) + def _optimize_dataset_for_operation(self, ds: Dataset, operation: str) -> Dataset: + """Optimize dataset for specific operations.""" + if self._resource_manager is None: + return ds - try: - col_names = ds.schema().names - if timestamp_field not in col_names: - raise ValueError( - f"Timestamp field '{timestamp_field}' not found in columns: {col_names}" - ) - except Exception as e: - raise ValueError(f"Failed to get dataset schema: {e}") + dataset_size = ds.size_bytes() + requirements = self._resource_manager.estimate_processing_requirements( + dataset_size, operation + ) - if start_date or end_date: - try: - if start_date and end_date: + if requirements["can_fit_in_memory"]: + # Materialize small datasets for better performance + ds = ds.materialize() - def filter_func(row): - try: - ts = row[timestamp_field] - return start_date <= ts <= end_date - except KeyError: - raise KeyError( - f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" - ) + # Optimize partitioning + optimal_partitions = requirements["optimal_partitions"] + current_partitions = ds.num_blocks() - filtered_ds = ds.filter(filter_func) - elif start_date: + if current_partitions != optimal_partitions: + logger.debug( + f"Repartitioning dataset from {current_partitions} to {optimal_partitions} blocks" + ) + ds = ds.repartition(num_blocks=optimal_partitions) - def filter_func(row): - try: - ts = row[timestamp_field] - return ts >= start_date - except KeyError: - raise KeyError( - f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" - ) + return ds - filtered_ds = ds.filter(filter_func) - elif end_date: + def supports_remote_storage_export(self) -> bool: + """Check if remote storage export is supported.""" + return True # Ray supports remote storage natively - def filter_func(row): - try: - ts = row[timestamp_field] - return ts <= end_date - except KeyError: - raise KeyError( - f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" - ) + def get_feature_server_endpoint(self) -> Optional[str]: + """Get feature server endpoint if available.""" + return None # Ray offline store doesn't have a feature server endpoint - filtered_ds = ds.filter(filter_func) - else: - return ds + def get_infra_object_names(self) -> List[str]: + """Get infrastructure object names managed by this store.""" + return [] # Ray offline store doesn't manage persistent infrastructure objects - return filtered_ds - except Exception as e: - raise RuntimeError(f"Failed to filter by timestamp: {e}") + def plan_infra(self, config: RepoConfig, desired_registry_proto: Any) -> Any: + """Plan infrastructure changes.""" + # Ray offline store doesn't require infrastructure planning + return None - return ds + def update_infra( + self, + project: str, + tables_to_delete: List[Any], + tables_to_keep: List[Any], + entities_to_delete: List[Any], + entities_to_keep: List[Any], + partial: bool, + ) -> None: + """Update infrastructure.""" + # Ray offline store doesn't require infrastructure updates + pass + + def teardown_infra( + self, project: str, tables: List[Any], entities: List[Any] + ) -> None: + """Teardown infrastructure.""" + # Ray offline store doesn't require infrastructure teardown + pass @staticmethod - def get_historical_features( + def offline_write_batch( config: RepoConfig, - feature_views: List[FeatureView], - feature_refs: List[str], - entity_df: Union[pd.DataFrame, str], - registry: BaseRegistry, - project: str, - full_feature_names: bool = False, - ) -> RetrievalJob: - store = RayOfflineStore() - store._init_ray(config) + feature_view: FeatureView, + table: pa.Table, + progress: Optional[Callable[[int], Any]] = None, + ) -> None: + """Write batch data using Ray operations with performance monitoring.""" + import time - # Load entity_df as Ray dataset for distributed processing - if isinstance(entity_df, str): - entity_ds = ray.data.read_csv(entity_df) - original_entity_df = pd.read_csv(entity_df) - else: - entity_ds = ray.data.from_pandas(entity_df) - original_entity_df = entity_df.copy() + start_time = time.time() - # Make entity dataframe timezone aware and normalize timestamp - original_entity_df = _ensure_timestamp_compatibility( - original_entity_df, ["event_timestamp"] - ) + RayOfflineStore._ensure_ray_initialized(config) - # Parse feature_refs and get ODFVs - on_demand_feature_views = OnDemandFeatureView.get_requested_odfvs( - feature_refs, project, registry - ) + repo_path = getattr(config, "repo_path", None) or os.getcwd() + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + assert isinstance(feature_view.batch_source, FileSource) - # Validate request data for ODFVs - for odfv in on_demand_feature_views: - odfv_request_data_schema = odfv.get_request_data_schema() - for feature_name in odfv_request_data_schema.keys(): - if feature_name not in original_entity_df.columns: - raise RequestDataNotFoundInEntityDfException( - feature_name=feature_name, - feature_view_name=odfv.name, - ) + # Enhanced schema validation using safe utility + validation_result = _safe_validate_schema( + config, feature_view.batch_source, table.column_names, "offline_write_batch" + ) - # Filter out on-demand feature views from regular feature views - # ODFVs don't have data sources and are computed from base features - odfv_names = {odfv.name for odfv in on_demand_feature_views} - regular_feature_views = [ - fv for fv in feature_views if fv.name not in odfv_names - ] - - # Enhanced validation using offline_utils with safe fallback - _safe_validate_entity_dataframe( - original_entity_df, regular_feature_views, project, registry - ) - # Apply field mappings to entity dataset if needed - global_field_mappings = {} - for fv in regular_feature_views: - mapping = getattr(fv.batch_source, "field_mapping", None) - if mapping: - for k, v in mapping.items(): - global_field_mappings[v] = k - - if global_field_mappings: - cols_to_rename = { - v: k - for k, v in global_field_mappings.items() - if v in original_entity_df.columns - } - if cols_to_rename: - entity_ds = entity_ds.map_batches( - lambda batch: batch.rename(columns=cols_to_rename), - batch_format="pandas", - ) + if validation_result: + expected_schema, expected_columns = validation_result + # Try to reorder columns to match expected order if needed + if expected_columns != table.column_names and set(expected_columns) == set( + table.column_names + ): + logger.info("Reordering table columns to match expected schema") + table = table.select(expected_columns) - # Start with entity dataset - result_ds = entity_ds + batch_source_path = feature_view.batch_source.file_options.uri + feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) - # Process each regular feature view with intelligent join strategy - for fv in regular_feature_views: - fv_feature_refs = [ - ref - for ref in feature_refs - if ref.startswith(fv.projection.name_to_use() + ":") - ] - if not fv_feature_refs: - continue - # Get join configuration - entities = fv.entities or [] - entity_objs = [registry.get_entity(e, project) for e in entities] - original_join_keys, _, timestamp_field, created_col = _get_column_names( - fv, entity_objs - ) + # Use Ray Dataset for efficient writing + ds = ray.data.from_arrow(table) - # Apply join key mapping from projection if present - if fv.projection.join_key_map: - join_keys = [ - fv.projection.join_key_map.get(key, key) - for key in original_join_keys - ] + try: + # If the path points to a file, write directly to that file location + # If it points to a directory, write to that directory + if feature_path.endswith(".parquet"): + # For single file writes, check if file exists and append if it does + if os.path.exists(feature_path): + # Read existing data as Ray Dataset + existing_ds = ray.data.read_parquet(feature_path) + # Append new data using Ray operations + combined_ds = existing_ds.union(ds) + # Write combined data + combined_ds.write_parquet(feature_path) + else: + # Write new data + ds.write_parquet(feature_path) else: - join_keys = original_join_keys - - # Extract requested features - requested_feats = [ref.split(":", 1)[1] for ref in fv_feature_refs] - - # Validate requested features exist - available_feature_names = [f.name for f in fv.features] - missing_feats = [ - f for f in requested_feats if f not in available_feature_names - ] - if missing_feats: - raise KeyError( - f"Requested features {missing_feats} not found in feature view '{fv.name}' " - f"(available: {available_feature_names})" - ) - - # Load feature data as Ray dataset - source_path = store._get_source_path(fv.batch_source, config) - feature_ds = ray.data.read_parquet(source_path) - feature_size = feature_ds.size_bytes() - - # Apply field mapping to feature dataset if needed - field_mapping = getattr(fv.batch_source, "field_mapping", None) - if field_mapping: - feature_ds = feature_ds.map_batches( - lambda batch: batch.rename(columns=field_mapping), - batch_format="pandas", - ) - # Update join keys and timestamp field to mapped names - join_keys = [field_mapping.get(k, k) for k in join_keys] - timestamp_field = field_mapping.get(timestamp_field, timestamp_field) - if created_col: - created_col = field_mapping.get(created_col, created_col) - - # Ensure timestamp compatibility in entity dataset - if ( - timestamp_field != "event_timestamp" - and timestamp_field not in original_entity_df.columns - and "event_timestamp" in original_entity_df.columns - ): - - def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: - batch = batch.copy() - batch[timestamp_field] = batch["event_timestamp"] - return _normalize_timestamp_column( - batch, timestamp_field, inplace=True - ) + # Write to directory (multiple parquet files) + os.makedirs(feature_path, exist_ok=True) + ds.write_parquet(feature_path) - result_ds = result_ds.map_batches( - add_timestamp_field, batch_format="pandas" - ) + # Call progress callback if provided + if progress: + progress(table.num_rows) - # Determine join strategy based on dataset sizes and cluster resources - if store._resource_manager is None: - raise ValueError("Resource manager not initialized") - requirements = store._resource_manager.estimate_processing_requirements( - feature_size, "join" - ) + except Exception as e: + logger.error(f"Failed to write batch data: {e}") + # Fallback to pandas-based writing + logger.info("Falling back to pandas-based writing") - if requirements["should_broadcast"]: - # Use broadcast join for small feature datasets - logger.info( - f"Using broadcast join for {fv.name} (size: {feature_size // 1024**2}MB)" - ) - feature_df = feature_ds.to_pandas() - feature_df = _ensure_timestamp_compatibility( - feature_df, [timestamp_field] - ) + # Convert to pandas for fallback + df = table.to_pandas() - if store._data_processor is None: - raise ValueError("Data processor not initialized") - result_ds = store._data_processor.broadcast_join_features( - result_ds, - feature_df, - join_keys, - timestamp_field, - requested_feats, - full_feature_names, - fv.projection.name_to_use(), - original_join_keys if fv.projection.join_key_map else None, - ) + if feature_path.endswith(".parquet"): + # Check if file exists and append if it does + if os.path.exists(feature_path): + # Read existing data + existing_df = pd.read_parquet(feature_path) + # Append new data + combined_df = pd.concat([existing_df, df], ignore_index=True) + # Write combined data + combined_df.to_parquet(feature_path, index=False) + else: + # Write new data + df.to_parquet(feature_path, index=False) else: - # Use distributed windowed join for large feature datasets - logger.info( - f"Using distributed join for {fv.name} (size: {feature_size // 1024**2}MB)" - ) + # Write to directory (multiple parquet files) + os.makedirs(feature_path, exist_ok=True) - # Ensure timestamp format in feature dataset - def normalize_timestamps(batch: pd.DataFrame) -> pd.DataFrame: - return _ensure_timestamp_compatibility(batch, [timestamp_field]) + # Convert to Ray dataset and write + ds_fallback = ray.data.from_pandas(df) + ds_fallback.write_parquet(feature_path) - feature_ds = feature_ds.map_batches( - normalize_timestamps, batch_format="pandas" - ) - - if store._data_processor is None: - raise ValueError("Data processor not initialized") - result_ds = store._data_processor.windowed_temporal_join( - result_ds, - feature_ds, - join_keys, - timestamp_field, - requested_feats, - window_size=config.offline_store.window_size_for_joins, - full_feature_names=full_feature_names, - feature_view_name=fv.projection.name_to_use(), - original_join_keys=original_join_keys - if fv.projection.join_key_map - else None, - ) + # Call progress callback if provided + if progress: + progress(table.num_rows) - # Final processing: clean up and ensure proper column structure - def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: - batch = batch.copy() + # Log performance metrics + duration = time.time() - start_time + logger.info( + f"Ray offline_write_batch performance: {table.num_rows} rows in {duration:.2f}s " + f"({table.num_rows / duration:.0f} rows/s)" + ) - # Preserve existing feature columns (including renamed ones) - existing_columns = set(batch.columns) + def online_write_batch( + self, + config: RepoConfig, + table: pa.Table, + progress: Optional[Callable[[int], Any]] = None, + ) -> None: + """Ray offline store doesn't support online writes.""" + raise NotImplementedError("Ray offline store doesn't support online writes") - # Re-attach any missing original entity columns that aren't already present - for col in original_entity_df.columns: - if col not in existing_columns: - # For missing columns, use values from original entity df - if len(batch) <= len(original_entity_df): - batch[col] = original_entity_df[col].iloc[: len(batch)].values - else: - # Repeat values if batch is larger - repeated_values = np.tile( - original_entity_df[col].values, - (len(batch) // len(original_entity_df) + 1), - ) - batch[col] = repeated_values[: len(batch)] + def get_table_query_string(self) -> str: + """Get table query string format.""" + return "file://{table_name}" - # Ensure event_timestamp is present - if "event_timestamp" not in batch.columns: - if "event_timestamp" in original_entity_df.columns: - batch["event_timestamp"] = ( - original_entity_df["event_timestamp"].iloc[: len(batch)].values - ) - batch = _normalize_timestamp_column( - batch, "event_timestamp", inplace=True - ) - elif timestamp_field in batch.columns: - batch["event_timestamp"] = batch[timestamp_field] + def get_table_column_names_and_types( + self, config: RepoConfig, data_source: DataSource + ) -> Iterable[Tuple[str, str]]: + """Get table column names and types efficiently using Ray.""" + return self.get_table_column_names_and_types_from_data_source( + config, data_source + ) - # Fix data types for feature columns using centralized type mapping utilities - batch = _convert_feature_column_types(batch, regular_feature_views) + def create_ray_dataset_from_table( + self, config: RepoConfig, data_source: DataSource + ) -> Dataset: + """Create a Ray Dataset from a data source.""" + self._init_ray(config) + source_path = self._get_source_path(data_source, config) + ds = ray.data.read_parquet(source_path) - return batch + # Apply field mapping if needed + field_mapping = getattr(data_source, "field_mapping", None) + if field_mapping: + ds = _apply_field_mapping(ds, field_mapping) - result_ds = result_ds.map_batches(finalize_result, batch_format="pandas") + return ds - # Storage path validation - storage_path = config.offline_store.storage_path - if not storage_path: - raise ValueError("Storage path must be set in config") + def get_dataset_statistics(self, ds: Dataset) -> Dict[str, Any]: + """Get comprehensive statistics for a Ray Dataset.""" + try: + stats = { + "num_rows": ds.count(), + "num_blocks": ds.num_blocks(), + "size_bytes": ds.size_bytes(), + "schema": ds.schema(), + } - # Create retrieval job following standard pattern - job = RayRetrievalJob(result_ds, staging_location=storage_path) - job._full_feature_names = full_feature_names - job._on_demand_feature_views = on_demand_feature_views - job._feature_refs = feature_refs - job._entity_df = original_entity_df - job._metadata = job._create_metadata() - return job + # Add column statistics if possible + try: + column_stats = {} + for col in ds.schema().names: + try: + column_stats[col] = { + "min": ds.min(col), + "max": ds.max(col), + "mean": ds.mean(col) + if ds.schema().field(col).type + in [pa.float32(), pa.float64(), pa.int32(), pa.int64()] + else None, + } + except Exception: + # Skip columns that don't support these operations + pass + stats["column_stats"] = column_stats + except Exception: + pass + + return stats + except Exception as e: + logger.warning(f"Failed to get dataset statistics: {e}") + return {"error": str(e)} def validate_data_source( self, @@ -1641,10 +1798,6 @@ def get_table_column_names_and_types_from_data_source( """Returns the list of column names and raw column types for a DataSource.""" return data_source.get_table_column_names_and_types(config=config) - def supports_remote_storage_export(self) -> bool: - """Check if remote storage export is supported.""" - return self._staging_location is not None - @staticmethod def _load_and_filter_dataset( source_path: str, @@ -1709,24 +1862,16 @@ def _load_and_filter_dataset( # Handle empty DataFrame case if df.empty: - empty_columns = ( - join_key_columns + feature_name_columns + timestamp_columns - ) - if not join_key_columns: - empty_columns.append(DUMMY_ENTITY_ID) - if "event_timestamp" not in empty_columns: - empty_columns.append("event_timestamp") - return _create_empty_dataframe_with_timestamp_columns( - empty_columns, timestamp_columns + return _handle_empty_dataframe_case( + join_key_columns, feature_name_columns, timestamp_columns ) # Build required columns list - all_required_columns = ( - join_key_columns + feature_name_columns + timestamp_columns + all_required_columns = _build_required_columns( + join_key_columns, feature_name_columns, timestamp_columns ) if not join_key_columns: df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL - all_required_columns.append(DUMMY_ENTITY_ID) # Select only the required columns that exist available_columns = [ @@ -1758,76 +1903,239 @@ def _load_and_filter_dataset( raise RuntimeError(f"Failed to load data from {source_path}: {e}") @staticmethod - def pull_latest_from_table_or_query( - config: RepoConfig, + def _load_and_filter_dataset_ray( + source_path: str, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, created_timestamp_column: Optional[str], - start_date: datetime, - end_date: datetime, - ) -> RetrievalJob: - store = RayOfflineStore() - store._init_ray(config) - - source_path = store._get_source_path(data_source, config) + start_date: Optional[datetime], + end_date: Optional[datetime], + ) -> Dataset: + """ + Ray-native method to load and filter dataset for distributed processing. + Args: + source_path: Path to the data source + data_source: DataSource object containing field mapping + join_key_columns: List of join key columns + feature_name_columns: List of feature columns + timestamp_field: Name of the timestamp field + created_timestamp_column: Optional created timestamp column + start_date: Optional start date for filtering + end_date: Optional end date for filtering + Returns: + Processed Ray Dataset + """ + try: + # Get field mapping for column renaming after loading + field_mapping = getattr(data_source, "field_mapping", None) - def _load(): - # Load and filter the dataset using the shared method - df = store._load_and_filter_dataset( - source_path, - data_source, - join_key_columns, - feature_name_columns, - timestamp_field, - created_timestamp_column, - start_date, - end_date, + # Load and filter the dataset using the original timestamp field name + ds = RayOfflineStore._create_filtered_dataset( + source_path, timestamp_field, start_date, end_date ) - # Handle deduplication (keep latest records) - specific to pull_latest - if join_key_columns and not df.empty: - # Get field mapping for proper column names - field_mapping = getattr(data_source, "field_mapping", None) - timestamp_field_mapped = ( - field_mapping.get(timestamp_field, timestamp_field) - if field_mapping - else timestamp_field - ) - created_timestamp_column_mapped = ( - field_mapping.get( - created_timestamp_column, created_timestamp_column - ) - if field_mapping and created_timestamp_column - else created_timestamp_column - ) - - # Build timestamp columns for sorting - timestamp_columns = [timestamp_field_mapped] - if created_timestamp_column_mapped: - timestamp_columns.append(created_timestamp_column_mapped) - - # Filter out timestamp columns that don't exist in the dataframe - existing_timestamp_columns = [ - col for col in timestamp_columns if col in df.columns - ] - - # Sort by join keys (ascending) and timestamps (descending for latest first) - sort_columns = join_key_columns + existing_timestamp_columns - if sort_columns: - df = df.sort_values( - sort_columns, - ascending=[True] * len(join_key_columns) - + [False] * len(existing_timestamp_columns), - ) - df = df.drop_duplicates(subset=join_key_columns, keep="first") - - return df - - return RayRetrievalJob( - _load, staging_location=config.offline_store.storage_path - ) + # Apply field mapping if needed using Ray operations + if field_mapping: + ds = _apply_field_mapping(ds, field_mapping) + + # Get mapped field names + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get(created_timestamp_column, created_timestamp_column) + if field_mapping and created_timestamp_column + else created_timestamp_column + ) + + # Build timestamp columns list + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + + # Normalize timestamp columns using Ray operations + ds = _normalize_timestamp_columns(ds, timestamp_columns) + + # Process dataset using Ray operations + def process_batch(batch: pd.DataFrame) -> pd.DataFrame: + # Apply timezone awareness + batch = make_df_tzaware(batch) + + # Handle empty batch case + if batch.empty: + return _handle_empty_dataframe_case( + join_key_columns, feature_name_columns, timestamp_columns + ) + + # Build required columns list + all_required_columns = _build_required_columns( + join_key_columns, feature_name_columns, timestamp_columns + ) + if not join_key_columns: + batch[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + + # Select only the required columns that exist + available_columns = [ + col for col in all_required_columns if col in batch.columns + ] + batch = batch[available_columns] + + # Ensure 'event_timestamp' column exists for pandas backend compatibility + if ( + "event_timestamp" not in batch.columns + and timestamp_field_mapped != "event_timestamp" + ): + if timestamp_field_mapped in batch.columns: + batch["event_timestamp"] = batch[timestamp_field_mapped] + + return batch + + ds = ds.map_batches(process_batch, batch_format="pandas") + + # Sort by timestamp (most recent first) using Ray operations + timestamp_columns_existing = [ + col for col in timestamp_columns if col in ds.schema().names + ] + if timestamp_columns_existing: + # Sort using Ray's native sorting + ds = ds.sort(timestamp_columns_existing, descending=True) + + return ds + + except Exception as e: + raise RuntimeError(f"Failed to load data from {source_path}: {e}") + + @staticmethod + def _pull_latest_processing_ray( + ds: Dataset, + join_key_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + field_mapping: Optional[Dict[str, str]] = None, + ) -> Dataset: + """ + Ray-native processing for pull_latest operations with deduplication. + Args: + ds: Ray Dataset to process + join_key_columns: List of join key columns + timestamp_field: Name of the timestamp field + created_timestamp_column: Optional created timestamp column + field_mapping: Optional field mapping dictionary + Returns: + Ray Dataset with latest records only + """ + if not join_key_columns: + return ds + + # Get mapped field names + timestamp_field_mapped = ( + field_mapping.get(timestamp_field, timestamp_field) + if field_mapping + else timestamp_field + ) + created_timestamp_column_mapped = ( + field_mapping.get(created_timestamp_column, created_timestamp_column) + if field_mapping and created_timestamp_column + else created_timestamp_column + ) + + # Build timestamp columns for sorting + timestamp_columns = [timestamp_field_mapped] + if created_timestamp_column_mapped: + timestamp_columns.append(created_timestamp_column_mapped) + + def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame: + if batch.empty: + return batch + + # Filter out timestamp columns that don't exist in the dataframe + existing_timestamp_columns = [ + col for col in timestamp_columns if col in batch.columns + ] + + # Sort by join keys (ascending) and timestamps (descending for latest first) + sort_columns = join_key_columns + existing_timestamp_columns + if sort_columns: + batch = batch.sort_values( + sort_columns, + ascending=[True] * len(join_key_columns) + + [False] * len(existing_timestamp_columns), + ) + batch = batch.drop_duplicates(subset=join_key_columns, keep="first") + + return batch + + return ds.map_batches(deduplicate_batch, batch_format="pandas") + + @staticmethod + def pull_latest_from_table_or_query( + config: RepoConfig, + data_source: DataSource, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str], + start_date: datetime, + end_date: datetime, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + source_path = store._get_source_path(data_source, config) + + def _load_ray_dataset(): + # Use Ray-native processing for better performance + ds = store._load_and_filter_dataset_ray( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) + + # Apply pull_latest processing (deduplication) using Ray operations + field_mapping = getattr(data_source, "field_mapping", None) + ds = store._pull_latest_processing_ray( + ds, + join_key_columns, + timestamp_field, + created_timestamp_column, + field_mapping, + ) + + return ds + + def _load_pandas_fallback(): + # Fallback to pandas processing for compatibility + return store._load_and_filter_dataset( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) + + # Try Ray-native processing first, fallback to pandas if needed + try: + return RayRetrievalJob( + _load_ray_dataset, staging_location=config.offline_store.storage_path + ) + except Exception as e: + logger.warning(f"Ray-native processing failed: {e}, falling back to pandas") + return RayRetrievalJob( + _load_pandas_fallback, + staging_location=config.offline_store.storage_path, + ) @staticmethod def pull_all_from_table_or_query( @@ -1849,7 +2157,21 @@ def pull_all_from_table_or_query( if not fs.exists(path_in_fs): raise FileNotFoundError(f"Parquet path does not exist: {source_path}") - def _load(): + def _load_ray_dataset(): + # Use Ray-native processing for better performance + return store._load_and_filter_dataset_ray( + source_path, + data_source, + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + start_date, + end_date, + ) + + def _load_pandas_fallback(): + # Fallback to pandas processing for compatibility return store._load_and_filter_dataset( source_path, data_source, @@ -1861,9 +2183,17 @@ def _load(): end_date, ) - return RayRetrievalJob( - _load, staging_location=config.offline_store.storage_path - ) + # Try Ray-native processing first, fallback to pandas if needed + try: + return RayRetrievalJob( + _load_ray_dataset, staging_location=config.offline_store.storage_path + ) + except Exception as e: + logger.warning(f"Ray-native processing failed: {e}, falling back to pandas") + return RayRetrievalJob( + _load_pandas_fallback, + staging_location=config.offline_store.storage_path, + ) @staticmethod def write_logged_features( @@ -1885,90 +2215,357 @@ def write_logged_features( path = FileSource.get_uri_for_file_path(repo_path, source_path) try: + # Use Ray dataset for efficient writing if isinstance(data, Path): ds = ray.data.read_parquet(str(data)) else: - ds = ray.data.from_pandas(pa.Table.to_pandas(data)) + # Convert PyArrow Table to Ray Dataset directly + ds = ray.data.from_arrow(data) - ds.materialize() + # Materialize for better performance + ds = ds.materialize() if not path.startswith(("s3://", "gs://")): os.makedirs(os.path.dirname(path), exist_ok=True) + # Use Ray's native write operations ds.write_parquet(path) except Exception as e: raise RuntimeError(f"Failed to write logged features: {e}") @staticmethod - def offline_write_batch( + def create_saved_dataset_destination( config: RepoConfig, - feature_view: FeatureView, - table: pa.Table, - progress: Optional[Callable[[int], Any]] = None, - ) -> None: - RayOfflineStore._ensure_ray_initialized(config) + name: str, + path: Optional[str] = None, + ) -> SavedDatasetStorage: + """Create a saved dataset destination for Ray offline store.""" - repo_path = getattr(config, "repo_path", None) or os.getcwd() - ray_config = config.offline_store - assert isinstance(ray_config, RayOfflineStoreConfig) - assert isinstance(feature_view.batch_source, FileSource) + if path is None: + ray_config = config.offline_store + assert isinstance(ray_config, RayOfflineStoreConfig) + base_storage_path = ray_config.storage_path or "/tmp/ray-storage" + path = f"{base_storage_path}/saved_datasets/{name}.parquet" - # Enhanced schema validation using safe utility - validation_result = _safe_validate_schema( - config, feature_view.batch_source, table.column_names, "offline_write_batch" + return SavedDatasetFileStorage(path=path) + + @staticmethod + def _create_filtered_dataset( + source_path: str, + timestamp_field: str, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> Dataset: + """Helper method to create a filtered dataset based on timestamp range.""" + ds = ray.data.read_parquet(source_path) + + try: + col_names = ds.schema().names + if timestamp_field not in col_names: + raise ValueError( + f"Timestamp field '{timestamp_field}' not found in columns: {col_names}" + ) + except Exception as e: + raise ValueError(f"Failed to get dataset schema: {e}") + + if start_date or end_date: + try: + if start_date and end_date: + + def filter_func(row): + try: + ts = row[timestamp_field] + return start_date <= ts <= end_date + except KeyError: + raise KeyError( + f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" + ) + + filtered_ds = ds.filter(filter_func) + elif start_date: + + def filter_func(row): + try: + ts = row[timestamp_field] + return ts >= start_date + except KeyError: + raise KeyError( + f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" + ) + + filtered_ds = ds.filter(filter_func) + elif end_date: + + def filter_func(row): + try: + ts = row[timestamp_field] + return ts <= end_date + except KeyError: + raise KeyError( + f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" + ) + + filtered_ds = ds.filter(filter_func) + else: + return ds + + return filtered_ds + except Exception as e: + raise RuntimeError(f"Failed to filter by timestamp: {e}") + + return ds + + @staticmethod + def get_historical_features( + config: RepoConfig, + feature_views: List[FeatureView], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, + ) -> RetrievalJob: + store = RayOfflineStore() + store._init_ray(config) + + # Load entity_df as Ray dataset for distributed processing + if isinstance(entity_df, str): + entity_ds = ray.data.read_csv(entity_df) + # Keep a minimal pandas copy only for metadata creation + entity_df_sample = entity_ds.limit(1000).to_pandas() + else: + entity_ds = ray.data.from_pandas(entity_df) + entity_df_sample = entity_df.copy() + + # Make entity dataset timezone aware and normalize timestamp using Ray operations + entity_ds = _ensure_timestamp_compatibility(entity_ds, ["event_timestamp"]) + + # Parse feature_refs and get ODFVs + on_demand_feature_views = OnDemandFeatureView.get_requested_odfvs( + feature_refs, project, registry ) - if validation_result: - expected_schema, expected_columns = validation_result - # Try to reorder columns to match expected order if needed - if expected_columns != table.column_names and set(expected_columns) == set( - table.column_names + # Validate request data for ODFVs using sample + for odfv in on_demand_feature_views: + odfv_request_data_schema = odfv.get_request_data_schema() + for feature_name in odfv_request_data_schema.keys(): + if feature_name not in entity_df_sample.columns: + raise RequestDataNotFoundInEntityDfException( + feature_name=feature_name, + feature_view_name=odfv.name, + ) + + # Filter out on-demand feature views from regular feature views + # ODFVs don't have data sources and are computed from base features + odfv_names = {odfv.name for odfv in on_demand_feature_views} + regular_feature_views = [ + fv for fv in feature_views if fv.name not in odfv_names + ] + + # Enhanced validation using unified operations + _safe_validate_entity_dataframe( + entity_ds, regular_feature_views, project, registry + ) + + # Apply field mappings to entity dataset if needed using unified operations + global_field_mappings = {} + for fv in regular_feature_views: + mapping = getattr(fv.batch_source, "field_mapping", None) + if mapping: + for k, v in mapping.items(): + global_field_mappings[v] = k + + if global_field_mappings: + cols_to_rename = { + v: k + for k, v in global_field_mappings.items() + if v in entity_df_sample.columns + } + if cols_to_rename: + entity_ds = _apply_field_mapping(entity_ds, cols_to_rename) + + # Start with entity dataset - keep it as Ray dataset throughout + result_ds = entity_ds + + # Process each regular feature view with intelligent join strategy + for fv in regular_feature_views: + fv_feature_refs = [ + ref + for ref in feature_refs + if ref.startswith(fv.projection.name_to_use() + ":") + ] + if not fv_feature_refs: + continue + + # Get join configuration + entities = fv.entities or [] + entity_objs = [registry.get_entity(e, project) for e in entities] + original_join_keys, _, timestamp_field, created_col = _get_column_names( + fv, entity_objs + ) + + # Apply join key mapping from projection if present + if fv.projection.join_key_map: + join_keys = [ + fv.projection.join_key_map.get(key, key) + for key in original_join_keys + ] + else: + join_keys = original_join_keys + + # Extract requested features + requested_feats = [ref.split(":", 1)[1] for ref in fv_feature_refs] + + # Validate requested features exist + available_feature_names = [f.name for f in fv.features] + missing_feats = [ + f for f in requested_feats if f not in available_feature_names + ] + if missing_feats: + raise KeyError( + f"Requested features {missing_feats} not found in feature view '{fv.name}' " + f"(available: {available_feature_names})" + ) + + # Load feature data as Ray dataset + source_path = store._get_source_path(fv.batch_source, config) + feature_ds = ray.data.read_parquet(source_path) + feature_size = feature_ds.size_bytes() + + # Apply field mapping to feature dataset if needed using unified operations + field_mapping = getattr(fv.batch_source, "field_mapping", None) + if field_mapping: + feature_ds = _apply_field_mapping(feature_ds, field_mapping) + # Update join keys and timestamp field to mapped names + join_keys = [field_mapping.get(k, k) for k in join_keys] + timestamp_field = field_mapping.get(timestamp_field, timestamp_field) + if created_col: + created_col = field_mapping.get(created_col, created_col) + + # Ensure timestamp compatibility in entity dataset using unified operations + if ( + timestamp_field != "event_timestamp" + and timestamp_field not in entity_df_sample.columns + and "event_timestamp" in entity_df_sample.columns ): - logger.info("Reordering table columns to match expected schema") - table = table.select(expected_columns) - batch_source_path = feature_view.batch_source.file_options.uri - feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) + def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + batch[timestamp_field] = batch["event_timestamp"] + return batch - # If the path points to a file, write directly to that file location - # If it points to a directory, write to that directory - if feature_path.endswith(".parquet"): - # Convert PyArrow table to pandas DataFrame - df = table.to_pandas() + result_ds = result_ds.map_batches( + add_timestamp_field, batch_format="pandas" + ) + result_ds = _normalize_timestamp_columns(result_ds, timestamp_field) + + # Determine join strategy based on dataset sizes and cluster resources + if store._resource_manager is None: + raise ValueError("Resource manager not initialized") + requirements = store._resource_manager.estimate_processing_requirements( + feature_size, "join" + ) + + if requirements["should_broadcast"]: + # Use broadcast join for small feature datasets + logger.info( + f"Using broadcast join for {fv.name} (size: {feature_size // 1024**2}MB)" + ) + # Convert to pandas only for broadcast join + feature_df = feature_ds.to_pandas() + feature_df = _ensure_timestamp_compatibility( + feature_df, [timestamp_field] + ) - # Check if file exists and append if it does - if os.path.exists(feature_path): - # Read existing data - existing_df = pd.read_parquet(feature_path) - # Append new data - combined_df = pd.concat([existing_df, df], ignore_index=True) - # Write combined data - combined_df.to_parquet(feature_path, index=False) + if store._data_processor is None: + raise ValueError("Data processor not initialized") + result_ds = store._data_processor.broadcast_join_features( + result_ds, + feature_df, + join_keys, + timestamp_field, + requested_feats, + full_feature_names, + fv.projection.name_to_use(), + original_join_keys if fv.projection.join_key_map else None, + ) else: - # Write new data - df.to_parquet(feature_path, index=False) - else: - # Write to directory (multiple parquet files) - os.makedirs(feature_path, exist_ok=True) + # Use distributed windowed join for large feature datasets + logger.info( + f"Using distributed join for {fv.name} (size: {feature_size // 1024**2}MB)" + ) - # Convert PyArrow table to Ray dataset - ds = ray.data.from_arrow(table) + # Ensure timestamp format in feature dataset using unified operations + feature_ds = _ensure_timestamp_compatibility( + feature_ds, [timestamp_field] + ) - # Write to parquet - ds.write_parquet(feature_path) + if store._data_processor is None: + raise ValueError("Data processor not initialized") + result_ds = store._data_processor.windowed_temporal_join( + result_ds, + feature_ds, + join_keys, + timestamp_field, + requested_feats, + window_size=config.offline_store.window_size_for_joins, + full_feature_names=full_feature_names, + feature_view_name=fv.projection.name_to_use(), + original_join_keys=original_join_keys + if fv.projection.join_key_map + else None, + ) - @staticmethod - def create_saved_dataset_destination( - config: RepoConfig, - name: str, - path: Optional[str] = None, - ) -> SavedDatasetStorage: - """Create a saved dataset destination for Ray offline store.""" + # Final processing: clean up and ensure proper column structure using Ray operations + def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() - if path is None: - ray_config = config.offline_store - assert isinstance(ray_config, RayOfflineStoreConfig) - base_storage_path = ray_config.storage_path or "/tmp/ray-storage" - path = f"{base_storage_path}/saved_datasets/{name}.parquet" + # Preserve existing feature columns (including renamed ones) + existing_columns = set(batch.columns) - return SavedDatasetFileStorage(path=path) + # Re-attach any missing original entity columns that aren't already present + for col in entity_df_sample.columns: + if col not in existing_columns: + # For missing columns, use values from entity df sample + if len(batch) <= len(entity_df_sample): + batch[col] = entity_df_sample[col].iloc[: len(batch)].values + else: + # Repeat values if batch is larger + repeated_values = np.tile( + entity_df_sample[col].values, + (len(batch) // len(entity_df_sample) + 1), + ) + batch[col] = repeated_values[: len(batch)] + + # Ensure event_timestamp is present + if "event_timestamp" not in batch.columns: + if "event_timestamp" in entity_df_sample.columns: + batch["event_timestamp"] = ( + entity_df_sample["event_timestamp"].iloc[: len(batch)].values + ) + batch = _normalize_timestamp_columns( + batch, "event_timestamp", inplace=True + ) + elif timestamp_field in batch.columns: + batch["event_timestamp"] = batch[timestamp_field] + + return batch + + result_ds = result_ds.map_batches(finalize_result, batch_format="pandas") + + # Apply feature type conversion using unified operations + result_ds = _convert_feature_column_types(result_ds, regular_feature_views) + + # Storage path validation + storage_path = config.offline_store.storage_path + if not storage_path: + raise ValueError("Storage path must be set in config") + + # Create retrieval job following standard pattern + job = RayRetrievalJob(result_ds, staging_location=storage_path) + job._full_feature_names = full_feature_names + job._on_demand_feature_views = on_demand_feature_views + job._feature_refs = feature_refs + job._entity_df = entity_df_sample # Use sample for metadata creation + job._metadata = job._create_metadata() + return job From 266f4cdb2fed4a274b8e4d43718872e517664261 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Fri, 11 Jul 2025 14:08:20 +0530 Subject: [PATCH 07/10] feat: Added Ray Compute Engine Signed-off-by: ntkathole --- Makefile | 5 + docs/SUMMARY.md | 2 + docs/reference/compute-engine/README.md | 8 + docs/reference/compute-engine/ray.md | 393 +++++++ docs/reference/offline-stores/ray.md | 542 ++++++--- .../api/v1alpha1/featurestore_types.go | 3 +- .../feast-operator.clusterserviceversion.yaml | 2 +- .../manifests/feast.dev_featurestores.yaml | 2 + .../crd/bases/feast.dev_featurestores.yaml | 2 + infra/feast-operator/dist/install.yaml | 2 + .../feast/infra/compute_engines/dag/model.py | 1 + .../infra/compute_engines/local/compute.py | 13 +- .../infra/compute_engines/ray/__init__.py | 38 + .../infra/compute_engines/ray/compute.py | 259 +++++ .../feast/infra/compute_engines/ray/config.py | 69 ++ .../compute_engines/ray/feature_builder.py | 224 ++++ .../feast/infra/compute_engines/ray/job.py | 296 +++++ .../feast/infra/compute_engines/ray/nodes.py | 660 +++++++++++ .../contrib/ray_offline_store/ray.py | 1025 +++++------------ .../contrib/ray_repo_configuration.py | 17 + sdk/python/feast/infra/ray_shared_utils.py | 363 ++++++ .../feast/infra/registry/caching_registry.py | 10 +- sdk/python/feast/repo_config.py | 1 + .../transformation/pandas_transformation.py | 53 +- sdk/python/tests/doctest/test_all.py | 89 +- sdk/python/tests/integration/__init__.py | 1 + .../integration/compute_engines/__init__.py | 1 + .../compute_engines/ray_compute/__init__.py | 1 + .../ray_compute/repo_configuration.py | 72 ++ .../ray_compute/test_compute.py | 291 +++++ sdk/python/tests/unit/__init__.py | 1 + .../unit/infra/compute_engines/__init__.py | 1 + .../compute_engines/ray_compute/__init__.py | 1 + .../compute_engines/ray_compute/test_nodes.py | 346 ++++++ 34 files changed, 3818 insertions(+), 976 deletions(-) create mode 100644 docs/reference/compute-engine/ray.md create mode 100644 sdk/python/feast/infra/compute_engines/ray/__init__.py create mode 100644 sdk/python/feast/infra/compute_engines/ray/compute.py create mode 100644 sdk/python/feast/infra/compute_engines/ray/config.py create mode 100644 sdk/python/feast/infra/compute_engines/ray/feature_builder.py create mode 100644 sdk/python/feast/infra/compute_engines/ray/job.py create mode 100644 sdk/python/feast/infra/compute_engines/ray/nodes.py create mode 100644 sdk/python/feast/infra/ray_shared_utils.py create mode 100644 sdk/python/tests/integration/__init__.py create mode 100644 sdk/python/tests/integration/compute_engines/__init__.py create mode 100644 sdk/python/tests/integration/compute_engines/ray_compute/__init__.py create mode 100644 sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py create mode 100644 sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py create mode 100644 sdk/python/tests/unit/__init__.py create mode 100644 sdk/python/tests/unit/infra/compute_engines/__init__.py create mode 100644 sdk/python/tests/unit/infra/compute_engines/ray_compute/__init__.py create mode 100644 sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py diff --git a/Makefile b/Makefile index e088de84e99..7bc2570245c 100644 --- a/Makefile +++ b/Makefile @@ -320,6 +320,11 @@ test-python-universal-ray-offline: ## Run Python Ray offline store integration t not test_spark" \ sdk/python/tests +test-python-ray-compute-engine: ## Run Python Ray compute engine tests + PYTHONPATH='.' \ + python -m pytest --integration \ + sdk/python/tests/integration/compute_engines/ray_compute/ + test-python-universal-postgres-online: ## Run Python Postgres integration tests PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.postgres_online_store.postgres_repo_configuration \ diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 05ddc3f7be7..2e34687d6c7 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -107,6 +107,7 @@ * [Trino (contrib)](reference/offline-stores/trino.md) * [Azure Synapse + Azure SQL (contrib)](reference/offline-stores/mssql.md) * [Clickhouse (contrib)](reference/offline-stores/clickhouse.md) + * [Ray (contrib)](reference/offline-stores/ray.md) * [Remote Offline](reference/offline-stores/remote-offline-store.md) * [Online stores](reference/online-stores/README.md) * [Overview](reference/online-stores/overview.md) @@ -143,6 +144,7 @@ * [Snowflake](reference/compute-engine/snowflake.md) * [AWS Lambda (alpha)](reference/compute-engine/lambda.md) * [Spark (contrib)](reference/compute-engine/spark.md) + * [Ray (contrib)](reference/compute-engine/ray.md) * [Feature repository](reference/feature-repository/README.md) * [feature\_store.yaml](reference/feature-repository/feature-store-yaml.md) * [.feastignore](reference/feature-repository/feast-ignore.md) diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md index c4a2f87f54d..dad2ede75a6 100644 --- a/docs/reference/compute-engine/README.md +++ b/docs/reference/compute-engine/README.md @@ -57,6 +57,14 @@ An example of built output from FeatureBuilder: - Supports point-in-time joins and large-scale materialization - Integrates with `SparkOfflineStore` and `SparkMaterializationJob` +### ⚡ RayComputeEngine (contrib) + +- Distributed DAG execution via Ray +- Intelligent join strategies (broadcast vs distributed) +- Automatic resource management and optimization +- Integrates with `RayOfflineStore` and `RayMaterializationJob` +- See [Ray Compute Engine documentation](ray.md) for details + ### 🧪 LocalComputeEngine {% page-ref page="local.md" %} diff --git a/docs/reference/compute-engine/ray.md b/docs/reference/compute-engine/ray.md new file mode 100644 index 00000000000..a286867cd5f --- /dev/null +++ b/docs/reference/compute-engine/ray.md @@ -0,0 +1,393 @@ +# Ray Compute Engine (contrib) + +The Ray compute engine is a distributed compute implementation that leverages [Ray](https://www.ray.io/) for executing feature pipelines including transformations, aggregations, joins, and materializations. It provides scalable and efficient distributed processing for both `materialize()` and `get_historical_features()` operations. + +## Overview + +The Ray compute engine provides: +- **Distributed DAG Execution**: Executes feature computation DAGs across Ray clusters +- **Intelligent Join Strategies**: Automatic selection between broadcast and distributed joins +- **Lazy Evaluation**: Deferred execution for optimal performance +- **Resource Management**: Automatic scaling and resource optimization +- **Point-in-Time Joins**: Efficient temporal joins for historical feature retrieval + +## Architecture + +The Ray compute engine follows Feast's DAG-based architecture: + +``` +EntityDF → RayReadNode → RayJoinNode → RayFilterNode → RayAggregationNode → RayTransformationNode → Output +``` + +### Core Components + +| Component | Description | +|-----------|-------------| +| `RayComputeEngine` | Main engine implementing `ComputeEngine` interface | +| `RayFeatureBuilder` | Constructs DAG from Feature View definitions | +| `RayDAGNode` | Ray-specific DAG node implementations | +| `RayDAGRetrievalJob` | Executes retrieval plans and returns results | +| `RayMaterializationJob` | Handles materialization job tracking | + +## Configuration + +Configure the Ray compute engine in your `feature_store.yaml`: + +```yaml +project: my_project +registry: data/registry.db +provider: local +offline_store: + type: ray + storage_path: data/ray_storage +batch_engine: + type: ray.engine + max_workers: 4 # Optional: Maximum number of workers + enable_optimization: true # Optional: Enable performance optimizations + broadcast_join_threshold_mb: 100 # Optional: Broadcast join threshold (MB) + max_parallelism_multiplier: 2 # Optional: Parallelism multiplier + target_partition_size_mb: 64 # Optional: Target partition size (MB) + window_size_for_joins: "1H" # Optional: Time window for distributed joins + ray_address: localhost:10001 # Optional: Ray cluster address + use_ray_cluster: false # Optional: Use Ray cluster mode +``` + +### Configuration Options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `type` | string | Required | Must be `ray.engine` | +| `max_workers` | int | CPU count | Maximum number of Ray workers | +| `enable_optimization` | boolean | true | Enable performance optimizations | +| `broadcast_join_threshold_mb` | int | 100 | Size threshold for broadcast joins (MB) | +| `max_parallelism_multiplier` | int | 2 | Parallelism as multiple of CPU cores | +| `target_partition_size_mb` | int | 64 | Target partition size (MB) | +| `window_size_for_joins` | string | "1H" | Time window for distributed joins | +| `ray_address` | string | None | Ray cluster address | +| `use_ray_cluster` | boolean | false | Use Ray cluster mode | + +## Usage Examples + +### Basic Historical Feature Retrieval + +```python +from feast import FeatureStore +import pandas as pd +from datetime import datetime + +# Initialize feature store with Ray compute engine +store = FeatureStore("feature_store.yaml") + +# Create entity DataFrame +entity_df = pd.DataFrame({ + "driver_id": [1, 2, 3, 4, 5], + "event_timestamp": [datetime.now()] * 5 +}) + +# Get historical features using Ray compute engine +features = store.get_historical_features( + entity_df=entity_df, + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:total_distance" + ] +) + +# Convert to DataFrame +df = features.to_df() +print(f"Retrieved {len(df)} rows with {len(df.columns)} columns") +``` + +### Batch Materialization + +```python +from datetime import datetime, timedelta + +# Materialize features using Ray compute engine +store.materialize( + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), + feature_views=["driver_stats", "customer_stats"] +) + +# The Ray compute engine handles: +# - Distributed data processing +# - Optimal join strategies +# - Resource management +# - Progress tracking +``` + +### Large-Scale Feature Retrieval + +```python +# Handle large entity datasets efficiently +large_entity_df = pd.DataFrame({ + "driver_id": range(1, 1000000), # 1M entities + "event_timestamp": [datetime.now()] * 1000000 +}) + +# Ray compute engine automatically: +# - Partitions data optimally +# - Selects appropriate join strategies +# - Distributes computation across cluster +features = store.get_historical_features( + entity_df=large_entity_df, + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:total_distance", + "customer_stats:lifetime_value" + ] +).to_df() +``` + +### Advanced Configuration + +```yaml +# Production-ready configuration +batch_engine: + type: ray.engine + # Resource configuration + max_workers: 16 + max_parallelism_multiplier: 4 + + # Performance optimization + enable_optimization: true + broadcast_join_threshold_mb: 50 + target_partition_size_mb: 128 + + # Distributed join configuration + window_size_for_joins: "30min" + + # Ray cluster configuration + use_ray_cluster: true + ray_address: "ray://head-node:10001" +``` + +### Complete Example Configuration + +Here's a complete example configuration showing how to use Ray offline store with Ray compute engine: + +```yaml +# Complete example configuration for Ray offline store + Ray compute engine +# This shows how to use both components together for distributed processing + +project: my_feast_project +registry: data/registry.db +provider: local + +# Ray offline store configuration +# Handles data I/O operations (reading/writing data) +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data # Optional: Path for storing datasets + ray_address: localhost:10001 # Optional: Ray cluster address + use_ray_cluster: true # Optional: Use Ray cluster mode + +# Ray compute engine configuration +# Handles complex feature computation and distributed processing +batch_engine: + type: ray.engine + + # Resource configuration + max_workers: 8 # Maximum number of Ray workers + max_parallelism_multiplier: 2 # Parallelism as multiple of CPU cores + + # Performance optimization + enable_optimization: true # Enable performance optimizations + broadcast_join_threshold_mb: 100 # Broadcast join threshold (MB) + target_partition_size_mb: 64 # Target partition size (MB) + + # Distributed join configuration + window_size_for_joins: "1H" # Time window for distributed joins + + # Ray cluster configuration (inherits from offline_store if not specified) + ray_address: localhost:10001 # Ray cluster address + use_ray_cluster: true # Use Ray cluster mode + +# Optional: Online store configuration +online_store: + type: sqlite + path: data/online_store.db + +# Optional: Feature server configuration +feature_server: + port: 6566 + metrics_port: 8888 +``` + +## DAG Node Types + +The Ray compute engine implements several specialized DAG nodes: + +### RayReadNode + +Reads data from Ray-compatible sources: +- Supports Parquet, CSV, and other formats +- Handles partitioning and schema inference +- Applies field mappings and filters + +### RayJoinNode + +Performs distributed joins: +- **Broadcast Join**: For small datasets (<100MB) +- **Distributed Join**: For large datasets with time-based windowing +- **Automatic Strategy Selection**: Based on dataset size and cluster resources + +### RayFilterNode + +Applies filters and time-based constraints: +- TTL-based filtering +- Timestamp range filtering +- Custom predicate filtering + +### RayAggregationNode + +Handles feature aggregations: +- Windowed aggregations +- Grouped aggregations +- Custom aggregation functions + +### RayTransformationNode + +Applies feature transformations: +- Row-level transformations +- Column-level transformations +- Custom transformation functions + +### RayWriteNode + +Writes results to various targets: +- Online stores +- Offline stores +- Temporary storage + +## Join Strategies + +The Ray compute engine automatically selects optimal join strategies: + +### Broadcast Join + +Used for small feature datasets: +```python +# Automatically selected when feature data < 100MB +# Features are cached in Ray's object store +# Entities are distributed across cluster +# Each worker gets a copy of feature data +``` + +### Distributed Windowed Join + +Used for large feature datasets: +```python +# Automatically selected when feature data > 100MB +# Data is partitioned by time windows +# Point-in-time joins within each window +# Results are combined across windows +``` + +### Strategy Selection Logic + +```python +def select_join_strategy(feature_size_mb, threshold_mb): + if feature_size_mb < threshold_mb: + return "broadcast" + else: + return "distributed_windowed" +``` + +## Performance Optimization + +### Automatic Optimization + +The Ray compute engine includes several automatic optimizations: + +1. **Partition Optimization**: Automatically determines optimal partition sizes +2. **Join Strategy Selection**: Chooses between broadcast and distributed joins +3. **Resource Allocation**: Scales workers based on available resources +4. **Memory Management**: Handles out-of-core processing for large datasets + +### Manual Tuning + +For specific workloads, you can fine-tune performance: + +```yaml +batch_engine: + type: ray.engine + # Fine-tuning for high-throughput scenarios + broadcast_join_threshold_mb: 200 # Larger broadcast threshold + max_parallelism_multiplier: 1 # Conservative parallelism + target_partition_size_mb: 512 # Larger partitions + window_size_for_joins: "2H" # Larger time windows +``` + +### Monitoring and Metrics + +Monitor Ray compute engine performance: + +```python +import ray + +# Check cluster resources +resources = ray.cluster_resources() +print(f"Available CPUs: {resources.get('CPU', 0)}") +print(f"Available memory: {resources.get('memory', 0) / 1e9:.2f} GB") + +# Monitor job progress +job = store.get_historical_features(...) +# Ray compute engine provides built-in progress tracking +``` + +## Integration Examples + +### With Spark Offline Store + +```yaml +# Use Ray compute engine with Spark offline store +offline_store: + type: spark + spark_conf: + spark.executor.memory: "4g" + spark.executor.cores: "2" +batch_engine: + type: ray.engine + max_workers: 8 + enable_optimization: true +``` + +### With Cloud Storage + +```yaml +# Use Ray compute engine with cloud storage +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data +batch_engine: + type: ray.engine + use_ray_cluster: true + ray_address: "ray://ray-cluster:10001" + broadcast_join_threshold_mb: 50 +``` + +### With Feature Transformations + +```python +from feast import FeatureView, Field +from feast.types import Float64 +from feast.on_demand_feature_view import on_demand_feature_view + +@on_demand_feature_view( + sources=["driver_stats"], + schema=[Field(name="trips_per_hour", dtype=Float64)] +) +def trips_per_hour(features_df): + features_df["trips_per_hour"] = features_df["avg_daily_trips"] / 24 + return features_df + +# Ray compute engine handles transformations efficiently +features = store.get_historical_features( + entity_df=entity_df, + features=["trips_per_hour:trips_per_hour"] +) +``` + +For more information, see the [Ray documentation](https://docs.ray.io/en/latest/) and [Ray Data guide](https://docs.ray.io/en/latest/data/getting-started.html). \ No newline at end of file diff --git a/docs/reference/offline-stores/ray.md b/docs/reference/offline-stores/ray.md index b0ba1e145c8..fc7baed5965 100644 --- a/docs/reference/offline-stores/ray.md +++ b/docs/reference/offline-stores/ray.md @@ -1,37 +1,70 @@ # Ray Offline Store (contrib) -The Ray offline store is a distributed offline store implementation that leverages [Ray](https://www.ray.io/) for distributed data processing. It's particularly useful for large-scale feature engineering and retrieval operations. +> **⚠️ Contrib Plugin:** +> The Ray offline store is a contributed plugin. It may not be as stable or fully supported as core offline stores. Use with caution in production and report issues to the Feast community. + +The Ray offline store is a data I/O implementation that leverages [Ray](https://www.ray.io/) for reading and writing data from various sources. It focuses on efficient data access operations, while complex feature computation is handled by the [Ray Compute Engine](../compute-engine/ray.md). ## Overview The Ray offline store provides: -- Distributed data processing using Ray -- Support for both local and cluster modes -- Integration with various storage backends (local files, S3, etc.) -- Support for scalable batch materialization -- Saved dataset persistence for data analysis and model training +- Ray-based data reading from file sources (Parquet, CSV, etc.) +- Support for both local and distributed Ray clusters +- Integration with various storage backends (local files, S3, GCS, HDFS) +- Efficient data filtering and column selection +- Timestamp-based data processing with timezone awareness + + +## Functionality Matrix + + +| Method | Supported | +|----------------------------------|-----------| +| get_historical_features | Yes | +| pull_latest_from_table_or_query | Yes | +| pull_all_from_table_or_query | Yes | +| offline_write_batch | Yes | +| write_logged_features | Yes | + -## Optimization Features +| RetrievalJob Feature | Supported | +|----------------------------------|-----------| +| export to dataframe | Yes | +| export to arrow table | Yes | +| persist results in offline store| Yes | +| local execution of ODFVs | Yes | +| remote execution of ODFVs | No | +| preview query plan | Yes | +| read partitioned data | Yes | -### Intelligent Join Strategies -The Ray offline store now includes intelligent join strategy selection: +## ⚠️ Important: Resource Management -- **Broadcast Joins**: For small feature datasets (<100MB by default), data is stored in Ray's object store for efficient broadcasting -- **Distributed Windowed Joins**: For large datasets, uses time-based windowing for distributed point-in-time joins -- **Automatic Strategy Selection**: Chooses optimal join strategy based on dataset size and cluster resources +**By default, Ray will use all available system resources (CPU and memory).** This can cause issues in test environments or when experimenting locally, potentially leading to system crashes or unresponsiveness. -### Resource Management +**For testing and local experimentation, we strongly recommend:** -The store automatically detects and optimizes for your Ray cluster: +1. **Configure resource limits** in your `feature_store.yaml` (see [Resource Management and Testing](#resource-management-and-testing) section below) -- **Auto-scaling**: Adjusts parallelism based on available CPU cores -- **Memory Optimization**: Configures buffer sizes based on available memory -- **Partition Optimization**: Calculates optimal partition sizes for your workload +This will limit Ray to safe resource levels for testing and development. + + +## Architecture + +The Ray offline store follows Feast's architectural separation: +- **Ray Offline Store**: Handles data I/O operations (reading/writing data) +- **Ray Compute Engine**: Handles complex feature computation and joins +- **Clear Separation**: Each component has a single responsibility + +For complex feature processing, historical feature retrieval, and distributed joins, use the [Ray Compute Engine](../compute-engine/ray.md). ## Configuration -The Ray offline store can be configured in your `feature_store.yaml` file: +The Ray offline store can be configured in your `feature_store.yaml` file. Below are two main configuration patterns: + +### Basic Ray Offline Store + +For simple data I/O operations without distributed processing: ```yaml project: my_project @@ -39,34 +72,211 @@ registry: data/registry.db provider: local offline_store: type: ray - storage_path: data/ray_storage # Optional: Path for materialized data + storage_path: data/ray_storage # Optional: Path for storing datasets ray_address: localhost:10001 # Optional: Ray cluster address use_ray_cluster: false # Optional: Whether to use Ray cluster - # New optimization settings - broadcast_join_threshold_mb: 100 # Optional: Threshold for broadcast joins (MB) - enable_distributed_joins: true # Optional: Enable distributed join strategies - max_parallelism_multiplier: 2 # Optional: Max parallelism as multiple of CPU cores - target_partition_size_mb: 64 # Optional: Target partition size (MB) - window_size_for_joins: "1H" # Optional: Time window size for distributed joins +``` + +### Ray Offline Store + Compute Engine + +For distributed feature processing with advanced capabilities: + +```yaml +project: my_project +registry: data/registry.db +provider: local + +# Ray offline store for data I/O operations +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data # Optional: Path for storing datasets + ray_address: localhost:10001 # Optional: Ray cluster address + use_ray_cluster: true # Optional: Use Ray cluster mode + +# Ray compute engine for distributed feature processing +batch_engine: + type: ray.engine + + # Resource configuration + max_workers: 8 # Maximum number of Ray workers + max_parallelism_multiplier: 2 # Parallelism as multiple of CPU cores + + # Performance optimization + enable_optimization: true # Enable performance optimizations + broadcast_join_threshold_mb: 100 # Broadcast join threshold (MB) + target_partition_size_mb: 64 # Target partition size (MB) + + # Distributed join configuration + window_size_for_joins: "1H" # Time window for distributed joins + enable_distributed_joins: true # Enable distributed joins + + # Ray cluster configuration (optional) + ray_address: localhost:10001 # Ray cluster address + use_ray_cluster: true # Use Ray cluster mode + staging_location: s3://my-bucket/staging # Remote staging location +``` + +### Local Development Configuration + +For local development and testing: + +```yaml +project: my_local_project +registry: data/registry.db +provider: local + +offline_store: + type: ray + storage_path: ./data/ray_storage + # Conservative settings for local development + broadcast_join_threshold_mb: 25 + max_parallelism_multiplier: 1 + target_partition_size_mb: 16 + enable_ray_logging: false + # Memory constraints to prevent OOM in test/development environments + ray_conf: + num_cpus: 1 + object_store_memory: 104857600 # 100MB + _memory: 524288000 # 500MB + +batch_engine: + type: ray.engine + max_workers: 2 + enable_optimization: false +``` + +### Production Configuration + +For production deployments with distributed Ray cluster: + +```yaml +project: my_production_project +registry: s3://my-bucket/registry.db +provider: local + +offline_store: + type: ray + storage_path: s3://my-production-bucket/feast-data + ray_address: "ray://production-head-node:10001" + use_ray_cluster: true + +batch_engine: + type: ray.engine + max_workers: 32 + max_parallelism_multiplier: 4 + enable_optimization: true + broadcast_join_threshold_mb: 50 + target_partition_size_mb: 128 + window_size_for_joins: "30min" + ray_address: "ray://production-head-node:10001" + use_ray_cluster: true + staging_location: s3://my-production-bucket/staging ``` ### Configuration Options +#### Ray Offline Store Options + | Option | Type | Default | Description | |--------|------|---------|-------------| | `type` | string | Required | Must be `feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore` or `ray` | -| `storage_path` | string | None | Path for storing materialized data (e.g., "s3://my-bucket/data") | +| `storage_path` | string | None | Path for storing temporary files and datasets | | `ray_address` | string | None | Address of the Ray cluster (e.g., "localhost:10001") | | `use_ray_cluster` | boolean | false | Whether to use Ray cluster mode | -| `broadcast_join_threshold_mb` | int | 100 | Size threshold (MB) below which broadcast joins are used | -| `enable_distributed_joins` | boolean | true | Enable intelligent distributed join strategies | -| `max_parallelism_multiplier` | int | 2 | Maximum parallelism as multiple of CPU cores | -| `target_partition_size_mb` | int | 64 | Target size for data partitions (MB) | -| `window_size_for_joins` | string | "1H" | Time window size for distributed temporal joins | +| `ray_conf` | dict | None | Ray initialization parameters for resource management (e.g., memory, CPU limits) | + +#### Ray Compute Engine Options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `type` | string | Required | Must be `ray.engine` | +| `max_workers` | int | CPU count | Maximum number of Ray workers | +| `enable_optimization` | boolean | true | Enable performance optimizations | +| `broadcast_join_threshold_mb` | int | 100 | Size threshold for broadcast joins (MB) | +| `max_parallelism_multiplier` | int | 2 | Parallelism as multiple of CPU cores | +| `target_partition_size_mb` | int | 64 | Target partition size (MB) | +| `window_size_for_joins` | string | "1H" | Time window for distributed joins | +| `enable_distributed_joins` | boolean | true | Enable distributed joins for large datasets | +| `staging_location` | string | None | Remote path for batch materialization jobs | + +## Resource Management and Testing + +### Overview + +**By default, Ray will use all available system resources (CPU and memory).** This can cause issues in test environments or when experimenting locally, potentially leading to system crashes or unresponsiveness. + +### Resource Configuration + +For custom resource control, configure limits in your `feature_store.yaml`: + +#### Conservative Settings (Local Development/Testing) + +```yaml +offline_store: + type: ray + storage_path: ./data/ray_storage + # Resource optimization settings + broadcast_join_threshold_mb: 25 # Smaller datasets for broadcast joins + max_parallelism_multiplier: 1 # Reduced parallelism + target_partition_size_mb: 16 # Smaller partition sizes + enable_ray_logging: false # Disable verbose logging + # Memory constraints to prevent OOM in test environments + ray_conf: + num_cpus: 1 + object_store_memory: 104857600 # 100MB + _memory: 524288000 # 500MB +``` + +#### Production Settings + +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data + ray_address: "ray://production-cluster:10001" + use_ray_cluster: true + # Optimized for production workloads + broadcast_join_threshold_mb: 100 + max_parallelism_multiplier: 2 + target_partition_size_mb: 64 + enable_ray_logging: true +``` + +### Resource Configuration Options + +| Setting | Default | Description | Testing Recommendation | +|---------|---------|-------------|------------------------| +| `broadcast_join_threshold_mb` | 100 | Size threshold for broadcast joins (MB) | 25 | +| `max_parallelism_multiplier` | 2 | Parallelism as multiple of CPU cores | 1 | +| `target_partition_size_mb` | 64 | Target partition size (MB) | 16 | +| `enable_ray_logging` | false | Enable Ray progress bars and logging | false | + +### Environment-Specific Recommendations + +#### Local Development +```yaml +# feature_store.yaml +offline_store: + type: ray + broadcast_join_threshold_mb: 25 + max_parallelism_multiplier: 1 + target_partition_size_mb: 16 +``` + +#### Production Clusters +```yaml +# feature_store.yaml +offline_store: + type: ray + ray_address: "ray://cluster-head:10001" + use_ray_cluster: true + broadcast_join_threshold_mb: 200 + max_parallelism_multiplier: 4 +``` ## Usage Examples -### Basic Usage +### Basic Data Source Reading ```python from feast import FeatureStore, FeatureView, FileSource @@ -78,7 +288,6 @@ driver_stats = FeatureView( name="driver_stats", entities=["driver_id"], ttl=timedelta(days=1), - online=True, source=FileSource( path="data/driver_stats.parquet", timestamp_field="event_timestamp", @@ -92,82 +301,71 @@ driver_stats = FeatureView( # Initialize feature store store = FeatureStore("feature_store.yaml") -# Get historical features -entity_df = pd.DataFrame({ - "driver_id": [1, 2, 3], - "event_timestamp": [datetime.now()] * 3 -}) - -features = store.get_historical_features( - entity_df=entity_df, - features=[ - "driver_stats:avg_daily_trips" - ] -).to_df() +# The Ray offline store handles data I/O operations +# For complex feature computation, use Ray Compute Engine ``` -### Optimized Configuration for Large Datasets +### Direct Data Access -```yaml -offline_store: - type: ray - storage_path: s3://my-bucket/feast-data - use_ray_cluster: true - ray_address: ray://head-node:10001 - # Optimize for large datasets - broadcast_join_threshold_mb: 50 # Smaller threshold for large clusters - max_parallelism_multiplier: 4 # Higher parallelism for more CPUs - target_partition_size_mb: 128 # Larger partitions for better throughput - window_size_for_joins: "30min" # Smaller windows for better distribution -``` - -### High-Performance Feature Retrieval +The Ray offline store provides direct access to underlying data: ```python -# For large-scale feature retrieval with millions of entities -large_entity_df = pd.DataFrame({ - "driver_id": range(1, 1000000), # 1M drivers - "event_timestamp": [datetime.now()] * 1000000 -}) +from feast.infra.offline_stores.contrib.ray_offline_store.ray import RayOfflineStore +from datetime import datetime, timedelta + +# Pull latest data from a table +job = RayOfflineStore.pull_latest_from_table_or_query( + config=store.config, + data_source=driver_stats.source, + join_key_columns=["driver_id"], + feature_name_columns=["avg_daily_trips"], + timestamp_field="event_timestamp", + created_timestamp_column=None, + start_date=datetime.now() - timedelta(days=7), + end_date=datetime.now(), +) -# The Ray offline store will automatically: -# 1. Detect large dataset and use distributed joins -# 2. Partition data optimally across cluster -# 3. Use appropriate join strategy based on feature data size -features = store.get_historical_features( - entity_df=large_entity_df, - features=[ - "driver_stats:avg_daily_trips", - "driver_stats:total_distance" - ] -).to_df() +# Convert to pandas DataFrame +df = job.to_df() +print(f"Retrieved {len(df)} rows") + +# Convert to Arrow Table +arrow_table = job.to_arrow() + +# Get Ray dataset directly +ray_dataset = job.to_ray_dataset() ``` -### Saved Dataset Persistence +### Batch Writing -The Ray offline store supports persisting datasets for later analysis and model training: +The Ray offline store supports batch writing for materialization: ```python -from feast import FeatureStore -from feast.infra.offline_stores.file_source import SavedDatasetFileStorage +import pyarrow as pa +from feast import FeatureView -# Initialize feature store -store = FeatureStore("feature_store.yaml") - -# Get historical features -entity_df = pd.DataFrame({ +# Create sample data +data = pa.table({ "driver_id": [1, 2, 3, 4, 5], + "avg_daily_trips": [10.5, 15.2, 8.7, 12.3, 9.8], "event_timestamp": [datetime.now()] * 5 }) -# Create a retrieval job -job = store.get_historical_features( - entity_df=entity_df, - features=[ - "driver_stats:avg_daily_trips", - "driver_stats:total_trips" - ] +# Write batch data +RayOfflineStore.offline_write_batch( + config=store.config, + feature_view=driver_stats, + table=data, + progress=lambda x: print(f"Wrote {x} rows") ) +``` + +### Saved Dataset Persistence + +The Ray offline store supports persisting datasets for later analysis: + +```python +from feast.infra.offline_stores.file_source import SavedDatasetFileStorage # Create storage destination storage = SavedDatasetFileStorage(path="data/training_dataset.parquet") @@ -180,84 +378,33 @@ saved_dataset = store.create_saved_dataset( from_=job, name="driver_training_dataset", storage=storage, - tags={"purpose": "model_training", "version": "v1"} + tags={"purpose": "data_access", "version": "v1"} ) print(f"Saved dataset created: {saved_dataset.name}") ``` -### Remote Storage Persistence +### Remote Storage Support -You can persist datasets to remote storage for distributed access: +The Ray offline store supports various remote storage backends: ```python -# Persist to S3 +# S3 storage s3_storage = SavedDatasetFileStorage(path="s3://my-bucket/datasets/driver_features.parquet") job.persist(s3_storage, allow_overwrite=True) -# Persist to Google Cloud Storage +# Google Cloud Storage gcs_storage = SavedDatasetFileStorage(path="gs://my-project-bucket/datasets/driver_features.parquet") job.persist(gcs_storage, allow_overwrite=True) -# Persist to HDFS +# HDFS hdfs_storage = SavedDatasetFileStorage(path="hdfs://namenode:8020/datasets/driver_features.parquet") job.persist(hdfs_storage, allow_overwrite=True) ``` -### Retrieving Saved Datasets - -You can retrieve previously saved datasets: - -```python -# Retrieve a saved dataset -saved_dataset = store.get_saved_dataset("driver_training_dataset") - -# Convert to different formats -df = saved_dataset.to_df() # Pandas DataFrame -arrow_table = saved_dataset.to_arrow() # PyArrow Table - -# Get dataset metadata -print(f"Dataset features: {saved_dataset.features}") -print(f"Join keys: {saved_dataset.join_keys}") -print(f"Min timestamp: {saved_dataset.min_event_timestamp}") -print(f"Max timestamp: {saved_dataset.max_event_timestamp}") -``` - -### Batch Materialization with Persistence - -Combine batch materialization with dataset persistence: - -```python -from datetime import datetime, timedelta - -# Materialize features for the last 30 days -store.materialize( - start_date=datetime.now() - timedelta(days=30), - end_date=datetime.now(), - feature_views=["driver_stats"] -) - -# Get historical features for the materialized period -entity_df = pd.DataFrame({ - "driver_id": list(range(1, 1001)), # 1000 drivers - "event_timestamp": [datetime.now()] * 1000 -}) - -job = store.get_historical_features( - entity_df=entity_df, - features=["driver_stats:avg_daily_trips"] -) - -# Persist to remote storage for distributed access -remote_storage = SavedDatasetFileStorage( - path="s3://my-bucket/large_datasets/driver_features_30d.parquet" -) -job.persist(remote_storage, allow_overwrite=True) -``` - ### Using Ray Cluster -To use Ray in cluster mode for maximum performance: +To use Ray in cluster mode for distributed data access: 1. Start a Ray cluster: ```bash @@ -270,9 +417,7 @@ offline_store: type: ray ray_address: localhost:10001 use_ray_cluster: true - # Cluster-optimized settings - max_parallelism_multiplier: 3 - target_partition_size_mb: 256 + storage_path: s3://my-bucket/features ``` 3. For multiple worker nodes: @@ -281,36 +426,87 @@ offline_store: ray start --address='head-node-ip:10001' ``` -### Remote Storage +### Data Source Validation + +The Ray offline store validates data sources to ensure compatibility: + +```python +from feast.infra.offline_stores.contrib.ray_offline_store.ray import RayOfflineStore + +# Validate a data source +try: + RayOfflineStore.validate_data_source(store.config, driver_stats.source) + print("Data source is valid") +except Exception as e: + print(f"Data source validation failed: {e}") +``` + +## Limitations + +The Ray offline store has the following limitations: + +1. **File Sources Only**: Currently supports only `FileSource` data sources +2. **No Direct SQL**: Does not support SQL query interfaces +3. **No Online Writes**: Cannot write directly to online stores +4. **Limited Transformations**: Complex feature transformations should use the Ray Compute Engine + +## Integration with Ray Compute Engine + +For complex feature processing operations, use the Ray offline store in combination with the [Ray Compute Engine](../compute-engine/ray.md). See the **Ray Offline Store + Compute Engine** configuration example in the [Configuration](#configuration) section above for a complete setup. + +The Ray offline store provides the data I/O foundation, while the Ray compute engine handles: +- **Point-in-time joins**: Efficient temporal joins for historical feature retrieval +- **Feature aggregations**: Distributed aggregations across time windows +- **Complex transformations**: Advanced feature transformations and computations +- **Historical feature retrieval**: `get_historical_features()` with distributed processing +- **Distributed processing optimization**: Automatic join strategy selection and resource management +- **Materialization**: Distributed batch materialization with progress tracking -For large-scale materialization, you can use remote storage: +For more advanced troubleshooting, refer to the [Ray documentation](https://docs.ray.io/en/latest/data/getting-started.html). + +## Quick Reference + +### Configuration Templates + +**Basic Ray Offline Store** (local development): ```yaml offline_store: type: ray - storage_path: s3://my-bucket/features + storage_path: ./data/ray_storage + # Conservative settings for local development + broadcast_join_threshold_mb: 25 + max_parallelism_multiplier: 1 + target_partition_size_mb: 16 + enable_ray_logging: false ``` -```python -# Materialize features to remote storage -store.materialize( - start_date=datetime.now() - timedelta(days=7), - end_date=datetime.now(), - feature_views=["driver_stats"] -) +**Ray Offline Store + Compute Engine** (distributed processing): +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data + use_ray_cluster: true + +batch_engine: + type: ray.engine + max_workers: 8 + enable_optimization: true + broadcast_join_threshold_mb: 100 ``` +### Key Commands -### Custom Optimization +```python +# Initialize feature store +store = FeatureStore("feature_store.yaml") -For specific workloads, you can fine-tune the configuration: +# Get historical features (uses compute engine if configured) +features = store.get_historical_features(entity_df=df, features=["fv:feature"]) -```yaml -offline_store: - type: ray - # Fine-tuning for high-throughput scenarios - broadcast_join_threshold_mb: 200 # Larger broadcast threshold - max_parallelism_multiplier: 1 # Conservative parallelism - target_partition_size_mb: 512 # Larger partitions - window_size_for_joins: "2H" # Larger time windows +# Direct data access (uses offline store) +job = RayOfflineStore.pull_latest_from_table_or_query(...) +df = job.to_df() ``` + +For complete examples, see the [Configuration](#configuration) section above. \ No newline at end of file diff --git a/infra/feast-operator/api/v1alpha1/featurestore_types.go b/infra/feast-operator/api/v1alpha1/featurestore_types.go index 8587fc98240..756c2e17ab1 100644 --- a/infra/feast-operator/api/v1alpha1/featurestore_types.go +++ b/infra/feast-operator/api/v1alpha1/featurestore_types.go @@ -315,7 +315,7 @@ var ValidOfflineStoreFilePersistenceTypes = []string{ // OfflineStoreDBStorePersistence configures the DB store persistence for the offline store service type OfflineStoreDBStorePersistence struct { // Type of the persistence type you want to use. - // +kubebuilder:validation:Enum=snowflake.offline;bigquery;redshift;spark;postgres;trino;athena;mssql;couchbase.offline;clickhouse + // +kubebuilder:validation:Enum=snowflake.offline;bigquery;redshift;spark;postgres;trino;athena;mssql;couchbase.offline;clickhouse;ray Type string `json:"type"` // Data store parameters should be placed as-is from the "feature_store.yaml" under the secret key. "registry_type" & "type" fields should be removed. SecretRef corev1.LocalObjectReference `json:"secretRef"` @@ -334,6 +334,7 @@ var ValidOfflineStoreDBStorePersistenceTypes = []string{ "mssql", "couchbase.offline", "clickhouse", + "ray", } // OnlineStore configures the online store service diff --git a/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml b/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml index 4d585d0feac..fcc382974f9 100644 --- a/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml +++ b/infra/feast-operator/bundle/manifests/feast-operator.clusterserviceversion.yaml @@ -50,7 +50,7 @@ metadata: } ] capabilities: Basic Install - createdAt: "2025-07-21T20:53:09Z" + createdAt: "2025-07-25T09:58:54Z" operators.operatorframework.io/builder: operator-sdk-v1.38.0 operators.operatorframework.io/project_layout: go.kubebuilder.io/v4 name: feast-operator.v0.51.0 diff --git a/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml b/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml index b7718e57104..701ed9bf052 100644 --- a/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml +++ b/infra/feast-operator/bundle/manifests/feast.dev_featurestores.yaml @@ -842,6 +842,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef @@ -4806,6 +4807,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef diff --git a/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml b/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml index b2fed6992d5..360fbba5453 100644 --- a/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml +++ b/infra/feast-operator/config/crd/bases/feast.dev_featurestores.yaml @@ -842,6 +842,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef @@ -4806,6 +4807,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef diff --git a/infra/feast-operator/dist/install.yaml b/infra/feast-operator/dist/install.yaml index b79489f7a29..add5c13c9a5 100644 --- a/infra/feast-operator/dist/install.yaml +++ b/infra/feast-operator/dist/install.yaml @@ -850,6 +850,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef @@ -4814,6 +4815,7 @@ spec: - mssql - couchbase.offline - clickhouse + - ray type: string required: - secretRef diff --git a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py index f77fdd0b6c9..5990eea6141 100644 --- a/sdk/python/feast/infra/compute_engines/dag/model.py +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -5,3 +5,4 @@ class DAGFormat(str, Enum): SPARK = "spark" PANDAS = "pandas" ARROW = "arrow" + RAY = "ray" diff --git a/sdk/python/feast/infra/compute_engines/local/compute.py b/sdk/python/feast/infra/compute_engines/local/compute.py index 341b20dee02..556468f5e1d 100644 --- a/sdk/python/feast/infra/compute_engines/local/compute.py +++ b/sdk/python/feast/infra/compute_engines/local/compute.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Union +from typing import Literal, Optional, Sequence, Union from feast import ( BatchFeatureView, @@ -22,6 +22,17 @@ LocalRetrievalJob, ) from feast.infra.registry.base_registry import BaseRegistry +from feast.repo_config import FeastConfigBaseModel + + +class LocalComputeEngineConfig(FeastConfigBaseModel): + """Configuration for Local Compute Engine.""" + + type: Literal["local"] = "local" + """Local Compute Engine type selector""" + + backend: Optional[str] = None + """Backend to use for DataFrame operations (e.g., 'pandas', 'polars')""" class LocalComputeEngine(ComputeEngine): diff --git a/sdk/python/feast/infra/compute_engines/ray/__init__.py b/sdk/python/feast/infra/compute_engines/ray/__init__.py new file mode 100644 index 00000000000..7b02d0ca615 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/__init__.py @@ -0,0 +1,38 @@ +""" +Ray Compute Engine for Feast + +This module provides a Ray-based compute engine for distributed feature computation. +It includes: +- RayComputeEngine: Main compute engine implementation +- RayComputeEngineConfig: Configuration for the compute engine +- Ray DAG nodes for distributed processing +""" + +from .compute import RayComputeEngine +from .config import RayComputeEngineConfig +from .feature_builder import RayFeatureBuilder +from .job import RayDAGRetrievalJob, RayMaterializationJob +from .nodes import ( + RayAggregationNode, + RayDedupNode, + RayFilterNode, + RayJoinNode, + RayReadNode, + RayTransformationNode, + RayWriteNode, +) + +__all__ = [ + "RayComputeEngine", + "RayComputeEngineConfig", + "RayDAGRetrievalJob", + "RayMaterializationJob", + "RayFeatureBuilder", + "RayReadNode", + "RayJoinNode", + "RayFilterNode", + "RayAggregationNode", + "RayDedupNode", + "RayTransformationNode", + "RayWriteNode", +] diff --git a/sdk/python/feast/infra/compute_engines/ray/compute.py b/sdk/python/feast/infra/compute_engines/ray/compute.py new file mode 100644 index 00000000000..3363d483a06 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/compute.py @@ -0,0 +1,259 @@ +import logging +from datetime import datetime +from typing import Sequence, Union + +import ray + +from feast import ( + BatchFeatureView, + Entity, + FeatureView, + OnDemandFeatureView, + StreamFeatureView, +) +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.ray.feature_builder import RayFeatureBuilder +from feast.infra.compute_engines.ray.job import ( + RayDAGRetrievalJob, + RayMaterializationJob, +) +from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.registry.base_registry import BaseRegistry + +logger = logging.getLogger(__name__) + + +class RayComputeEngine(ComputeEngine): + """ + Ray-based compute engine for distributed feature computation. + This engine uses Ray for distributed processing of features, enabling + efficient point-in-time joins, aggregations, and transformations across + large datasets. + """ + + def __init__( + self, + offline_store, + online_store, + repo_config, + **kwargs, + ): + super().__init__( + offline_store=offline_store, + online_store=online_store, + repo_config=repo_config, + **kwargs, + ) + self.config = repo_config.batch_engine + assert isinstance(self.config, RayComputeEngineConfig) + self._ensure_ray_initialized() + + def _ensure_ray_initialized(self): + """Ensure Ray is initialized with proper configuration.""" + if not ray.is_initialized(): + if self.config.use_ray_cluster and self.config.ray_address: + ray.init( + address=self.config.ray_address, + ignore_reinit_error=True, + include_dashboard=False, + ) + else: + ray_init_args = { + "ignore_reinit_error": True, + "include_dashboard": False, + } + + # Add configuration from ray_conf if provided + if self.config.ray_conf: + ray_init_args.update(self.config.ray_conf) + + ray.init(**ray_init_args) + + # Configure Ray context for optimal performance + from ray.data.context import DatasetContext + + ctx = DatasetContext.get_current() + ctx.enable_tensor_extension_casting = False + + # Log Ray cluster information + cluster_resources = ray.cluster_resources() + logger.info( + f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " + f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" + ) + + def update( + self, + project: str, + views_to_delete: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + views_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] + ], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + ): + """Ray compute engine doesn't require infrastructure updates.""" + pass + + def teardown_infra( + self, + project: str, + fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], + entities: Sequence[Entity], + ): + """Ray compute engine doesn't require infrastructure teardown.""" + pass + + def _materialize_one( + self, + registry: BaseRegistry, + task: MaterializationTask, + from_offline_store: bool = False, + **kwargs, + ) -> MaterializationJob: + """Materialize features for a single feature view.""" + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" + + if from_offline_store: + logger.warning( + "Materializing from offline store will be deprecated. " + "Please use the new materialization API." + ) + return self._materialize_from_offline_store( + registry=registry, + feature_view=task.feature_view, + start_date=task.start_time, + end_date=task.end_time, + project=task.project, + ) + + try: + # Build typed execution context + context = self.get_execution_context(registry, task) + + # Construct Feature Builder and execute + builder = RayFeatureBuilder(registry, task.feature_view, task, self.config) + plan = builder.build() + result = plan.execute(context) + + # Log execution results + logger.info(f"Materialization completed for {task.feature_view.name}") + + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED, + result=result, + ) + + except Exception as e: + logger.error(f"Materialization failed for {task.feature_view.name}: {e}") + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=e, + ) + + def _materialize_from_offline_store( + self, + registry: BaseRegistry, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + start_date: datetime, + end_date: datetime, + project: str, + ) -> MaterializationJob: + """Legacy materialization method for backward compatibility.""" + from feast.utils import _get_column_names + + job_id = f"{feature_view.name}-{start_date}-{end_date}" + + try: + # Get column information + entities = [ + registry.get_entity(name, project) for name in feature_view.entities + ] + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = _get_column_names(feature_view, entities) + + # Pull data from offline store + retrieval_job = self.offline_store.pull_latest_from_table_or_query( + config=self.repo_config, + data_source=feature_view.batch_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ) + + # Convert to Arrow Table and write to online store + arrow_table = retrieval_job.to_arrow() + # TODO: Implement proper online store writing with correct data format conversion + # self.online_store.online_write_batch(...) + logger.debug( + f"Materialization completed, arrow table has {arrow_table.num_rows} rows" + ) + + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED, + ) + + except Exception as e: + logger.error(f"Legacy materialization failed: {e}") + return RayMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=e, + ) + + def get_historical_features( + self, registry: BaseRegistry, task: HistoricalRetrievalTask + ) -> RetrievalJob: + """Get historical features using Ray DAG execution.""" + if isinstance(task.entity_df, str): + raise NotImplementedError( + "SQL-based entity_df is not yet supported in Ray DAG" + ) + + try: + # Build typed execution context + context = self.get_execution_context(registry, task) + + # Construct Feature Builder and build execution plan + builder = RayFeatureBuilder(registry, task.feature_view, task, self.config) + plan = builder.build() + + return RayDAGRetrievalJob( + plan=plan, + context=context, + config=self.repo_config, + full_feature_names=task.full_feature_name, + on_demand_feature_views=getattr(task, "on_demand_feature_views", None), + feature_refs=getattr(task, "feature_refs", None), + ) + + except Exception as e: + logger.error(f"Historical feature retrieval failed: {e}") + return RayDAGRetrievalJob( + plan=None, + context=None, + config=self.repo_config, + full_feature_names=task.full_feature_name, + on_demand_feature_views=getattr(task, "on_demand_feature_views", None), + feature_refs=getattr(task, "feature_refs", None), + error=e, + ) diff --git a/sdk/python/feast/infra/compute_engines/ray/config.py b/sdk/python/feast/infra/compute_engines/ray/config.py new file mode 100644 index 00000000000..5fc66b49659 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/config.py @@ -0,0 +1,69 @@ +"""Configuration for Ray compute engine.""" + +from datetime import timedelta +from typing import Dict, Literal, Optional + +from pydantic import StrictStr + +from feast.repo_config import FeastConfigBaseModel + + +class RayComputeEngineConfig(FeastConfigBaseModel): + """Configuration for Ray Compute Engine.""" + + type: Literal["ray.engine"] = "ray.engine" + """Ray Compute Engine type selector""" + + ray_address: Optional[str] = None + """Ray cluster address. If None, uses local Ray cluster.""" + + use_ray_cluster: bool = False + """Whether to use an existing Ray cluster.""" + + staging_location: Optional[StrictStr] = None + """Remote path for batch materialization jobs""" + + # Ray-specific performance configurations + broadcast_join_threshold_mb: int = 100 + """Threshold for using broadcast joins (in MB)""" + + enable_distributed_joins: bool = True + """Whether to enable distributed joins for large datasets""" + + max_parallelism_multiplier: int = 2 + """Multiplier for max parallelism based on available CPUs""" + + target_partition_size_mb: int = 64 + """Target partition size in MB""" + + window_size_for_joins: str = "1H" + """Window size for windowed temporal joins""" + + ray_conf: Optional[Dict[str, str]] = None + """Ray configuration parameters""" + + # Additional configuration options + max_workers: Optional[int] = None + """Maximum number of Ray workers. If None, uses all available cores.""" + + enable_optimization: bool = True + """Enable automatic performance optimizations.""" + + execution_timeout_seconds: Optional[int] = None + """Timeout for job execution in seconds.""" + + @property + def window_size_timedelta(self) -> timedelta: + """Convert window size string to timedelta.""" + if self.window_size_for_joins.endswith("H"): + hours = int(self.window_size_for_joins[:-1]) + return timedelta(hours=hours) + elif self.window_size_for_joins.endswith("min"): + minutes = int(self.window_size_for_joins[:-3]) + return timedelta(minutes=minutes) + elif self.window_size_for_joins.endswith("s"): + seconds = int(self.window_size_for_joins[:-1]) + return timedelta(seconds=seconds) + else: + # Default to 1 hour + return timedelta(hours=1) diff --git a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py new file mode 100644 index 00000000000..7f49accd0a0 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py @@ -0,0 +1,224 @@ +import logging +from typing import TYPE_CHECKING, Union + +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.feature_builder import FeatureBuilder +from feast.infra.compute_engines.ray.nodes import ( + RayAggregationNode, + RayDedupNode, + RayFilterNode, + RayJoinNode, + RayReadNode, + RayTransformationNode, + RayWriteNode, +) + +if TYPE_CHECKING: + from feast.infra.compute_engines.ray.config import RayComputeEngineConfig + +logger = logging.getLogger(__name__) + + +class RayFeatureBuilder(FeatureBuilder): + """ + Ray-specific feature builder that constructs execution plans using Ray DAG nodes. + This builder translates FeatureView definitions into Ray-optimized execution DAGs + that can leverage distributed computing for large-scale feature processing. + """ + + def __init__( + self, + registry, + feature_view, + task: Union[MaterializationTask, HistoricalRetrievalTask], + config: "RayComputeEngineConfig", + ): + super().__init__(registry, feature_view, task) + self.config = config + self.is_historical_retrieval = isinstance(task, HistoricalRetrievalTask) + + def build_source_node(self, view): + """Build the source node for reading feature data.""" + source = view.batch_source + start_time = self.task.start_time + end_time = self.task.end_time + column_info = self.get_column_info(view) + + node = RayReadNode( + name="source", + source=source, + column_info=column_info, + config=self.config, + start_time=start_time, + end_time=end_time, + ) + + self.nodes.append(node) + logger.debug(f"Built source node for {source}") + return node + + def build_join_node(self, view, input_nodes): + """Build the join node for entity-feature joining.""" + column_info = self.get_column_info(view) + node = RayJoinNode( + name="join", + column_info=column_info, + config=self.config, + # Pass entity_df information if this is a historical retrieval + is_historical_retrieval=self.is_historical_retrieval, + ) + for input_node in input_nodes: + node.add_input(input_node) + self.nodes.append(node) + logger.debug("Built join node") + return node + + def build_filter_node(self, view, input_node): + """Build the filter node for TTL and custom filtering.""" + filter_expr = None + if hasattr(view, "filter"): + filter_expr = view.filter + + ttl = getattr(view, "ttl", None) + column_info = self.get_column_info(view) + + node = RayFilterNode( + name="filter", + column_info=column_info, + config=self.config, + ttl=ttl, + filter_condition=filter_expr, + ) + + node.add_input(input_node) + self.nodes.append(node) + logger.debug(f"Built filter node with TTL: {ttl}") + return node + + def build_aggregation_node(self, view, input_node): + """Build the aggregation node for feature aggregations.""" + if not hasattr(view, "aggregations"): + raise ValueError("Feature view does not have aggregations") + + aggregations = view.aggregations + group_by_keys = view.entities + + # Get timestamp field from batch source + timestamp_field = getattr( + view.batch_source, "timestamp_field", "event_timestamp" + ) + + node = RayAggregationNode( + name="aggregation", + aggregations=aggregations, + group_by_keys=group_by_keys, + timestamp_col=timestamp_field, + config=self.config, + ) + + node.add_input(input_node) + self.nodes.append(node) + logger.debug(f"Built aggregation node with {len(aggregations)} aggregations") + return node + + def build_dedup_node(self, view, input_node): + """Build the deduplication node for removing duplicates.""" + column_info = self.get_column_info(view) + node = RayDedupNode( + name="dedup", + column_info=column_info, + config=self.config, + ) + + node.add_input(input_node) + self.nodes.append(node) + logger.debug("Built dedup node") + return node + + def build_transformation_node(self, view, input_nodes): + """Build the transformation node for feature transformations.""" + transformation = None + + # Check for feature_transformation first + if hasattr(view, "feature_transformation") and view.feature_transformation: + transformation = view.feature_transformation + # For BatchFeatureView, also check for direct UDF + elif hasattr(view, "udf") and view.udf: + transformation = view.udf + else: + raise ValueError("Feature view does not have feature transformation or UDF") + + node = RayTransformationNode( + name="transformation", + transformation=transformation, + config=self.config, + ) + + for input_node in input_nodes: + node.add_input(input_node) + self.nodes.append(node) + transformation_name = getattr( + transformation, "name", getattr(transformation, "__name__", "unknown") + ) + logger.debug(f"Built transformation node: {transformation_name}") + return node + + def build_output_nodes(self, view, final_node): + """Build the output node for writing results.""" + node = RayWriteNode( + name="output", + feature_view=view, + config=self.config, + ) + + node.add_input(final_node) + self.nodes.append(node) + logger.debug("Built output node") + return node + + def build_validation_node(self, view, input_node): + """Build the validation node for data quality checks.""" + # For now, validation is handled in the retrieval job + # This could be extended to include Ray-specific validation logic + logger.debug("Validation node not implemented yet") + return input_node + + def build(self) -> ExecutionPlan: + """Build execution plan with optimized order for aggregation scenarios.""" + + # For historical retrieval with aggregations, use a different execution order + if self.is_historical_retrieval and self._should_aggregate(self.feature_view): + return self._build_aggregation_optimized_plan() + + # Use the default build logic for other scenarios + return super().build() + + def _build_aggregation_optimized_plan(self) -> ExecutionPlan: + """Build execution plan optimized for aggregation scenarios.""" + + # 1. Read source data + last_node = self.build_source_node(self.feature_view) + + # 2. Apply filters (TTL, custom filters) BEFORE aggregation + last_node = self.build_filter_node(self.feature_view, last_node) + + # 3. Aggregate across all historical records + last_node = self.build_aggregation_node(self.feature_view, last_node) + + # 4. Join with entity_df to get aggregated features for each entity + last_node = self.build_join_node(self.feature_view, [last_node]) + + # 5. Apply transformations to aggregated features + if self._should_transform(self.feature_view): + last_node = self.build_transformation_node(self.feature_view, [last_node]) + + # 6. Validation if needed + if self._should_validate(self.feature_view): + last_node = self.build_validation_node(self.feature_view, last_node) + + # 7. Output + last_node = self.build_output_nodes(self.feature_view, last_node) + + return ExecutionPlan(self.nodes) diff --git a/sdk/python/feast/infra/compute_engines/ray/job.py b/sdk/python/feast/infra/compute_engines/ray/job.py new file mode 100644 index 00000000000..bfc0943ca95 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/job.py @@ -0,0 +1,296 @@ +import logging +import uuid +from dataclasses import dataclass +from typing import List, Optional + +import pandas as pd +import pyarrow as pa +import ray +from ray.data import Dataset + +from feast import OnDemandFeatureView +from feast.dqm.errors import ValidationFailed +from feast.errors import SavedDatasetLocationAlreadyExists +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, +) +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.offline_stores.file_source import SavedDatasetFileStorage +from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDatasetStorage + +logger = logging.getLogger(__name__) + + +class RayDAGRetrievalJob(RetrievalJob): + """ + Ray-based retrieval job that executes a DAG plan to retrieve historical features. + """ + + def __init__( + self, + plan: Optional[ExecutionPlan], + context: Optional[ExecutionContext], + config: RepoConfig, + full_feature_names: bool, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + feature_refs: Optional[List[str]] = None, + metadata: Optional[RetrievalMetadata] = None, + error: Optional[BaseException] = None, + ): + super().__init__() + self._plan = plan + self._context = context + self._config = config + self._full_feature_names = full_feature_names + self._on_demand_feature_views = on_demand_feature_views or [] + self._feature_refs = feature_refs or [] + self._metadata = metadata + self._error = error + self._result_dataset: Optional[Dataset] = None + self._result_df: Optional[pd.DataFrame] = None + self._result_arrow: Optional[pa.Table] = None + + def error(self) -> Optional[BaseException]: + """Return any error that occurred during job execution.""" + return self._error + + def _ensure_executed(self) -> DAGValue: + """Ensure the execution plan has been executed.""" + if self._result_dataset is None and self._plan and self._context: + try: + result = self._plan.execute(self._context) + if hasattr(result, "data") and isinstance(result.data, Dataset): + self._result_dataset = result.data + else: + # If result is not a Ray Dataset, convert it + if isinstance(result.data, pd.DataFrame): + self._result_dataset = ray.data.from_pandas(result.data) + elif isinstance(result.data, pa.Table): + self._result_dataset = ray.data.from_arrow(result.data) + else: + raise ValueError( + f"Unsupported result type: {type(result.data)}" + ) + return result + except Exception as e: + self._error = e + logger.error(f"Ray DAG execution failed: {e}") + raise + elif self._result_dataset is None: + raise ValueError("No execution plan available or execution failed") + + # Return a mock DAGValue for compatibility + return DAGValue(data=self._result_dataset, format=DAGFormat.RAY) + + def to_ray_dataset(self) -> Dataset: + """Get the result as a Ray Dataset.""" + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + return self._result_dataset + + def to_df( + self, + validation_reference=None, + timeout: Optional[int] = None, + ) -> pd.DataFrame: + """Convert the result to a pandas DataFrame.""" + if self._result_df is None: + if self.on_demand_feature_views: + # Use parent implementation for ODFV processing + logger.info( + f"Processing {len(self.on_demand_feature_views)} on-demand feature views" + ) + self._result_df = super().to_df( + validation_reference=validation_reference, timeout=timeout + ) + else: + # Direct conversion from Ray Dataset + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + self._result_df = self._result_dataset.to_pandas() + + # Handle validation if provided + if validation_reference: + try: + validation_result = validation_reference.profile.validate( + self._result_df + ) + if not validation_result.is_success: + raise ValidationFailed(validation_result) + except ImportError: + logger.warning("DQM profiler not available, skipping validation") + except Exception as e: + logger.error(f"Validation failed: {e}") + raise ValueError(f"Data validation failed: {e}") + + return self._result_df + + def to_arrow( + self, + validation_reference=None, + timeout: Optional[int] = None, + ) -> pa.Table: + """Convert the result to an Arrow Table.""" + if self._result_arrow is None: + if self.on_demand_feature_views: + # Use parent implementation for ODFV processing + self._result_arrow = super().to_arrow( + validation_reference=validation_reference, timeout=timeout + ) + else: + # Direct conversion from Ray Dataset + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + self._result_arrow = self._result_dataset.to_pandas().to_arrow() + + # Handle validation if provided + if validation_reference: + try: + df = self._result_arrow.to_pandas() + validation_result = validation_reference.profile.validate(df) + if not validation_result.is_success: + raise ValidationFailed(validation_result) + except ImportError: + logger.warning("DQM profiler not available, skipping validation") + except Exception as e: + logger.error(f"Validation failed: {e}") + raise ValueError(f"Data validation failed: {e}") + + return self._result_arrow + + def to_remote_storage(self) -> list[str]: + """Write the result to remote storage.""" + if not self._config.batch_engine.staging_location: + raise ValueError("Staging location must be set for remote storage") + + try: + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + output_uri = ( + f"{self._config.batch_engine.staging_location}/{str(uuid.uuid4())}" + ) + self._result_dataset.write_parquet(output_uri) + return [output_uri] + except Exception as e: + raise RuntimeError(f"Failed to write to remote storage: {e}") + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: bool = False, + timeout: Optional[int] = None, + ) -> str: + """Persist the result to the specified storage.""" + if not isinstance(storage, SavedDatasetFileStorage): + raise ValueError( + f"Ray compute engine only supports SavedDatasetFileStorage, got {type(storage)}" + ) + + destination_path = storage.file_options.uri + + # Check if destination already exists + if not destination_path.startswith(("s3://", "gs://", "hdfs://")): + import os + + if not allow_overwrite and os.path.exists(destination_path): + raise SavedDatasetLocationAlreadyExists(location=destination_path) + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + + try: + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + self._result_dataset.write_parquet(destination_path) + return destination_path + except Exception as e: + raise RuntimeError(f"Failed to persist dataset to {destination_path}: {e}") + + def to_sql(self) -> str: + """Generate SQL representation of the execution plan.""" + if self._plan and self._context: + return self._plan.to_sql(self._context) + raise NotImplementedError("SQL generation not available without execution plan") + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return self._on_demand_feature_views + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + """Internal method to get DataFrame (used by parent class).""" + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + return self._result_dataset.to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: + """Internal method to get Arrow Table (used by parent class).""" + self._ensure_executed() + assert self._result_dataset is not None, ( + "Dataset should not be None after execution" + ) + return self._result_dataset.to_pandas().to_arrow() + + +@dataclass +class RayMaterializationJob(MaterializationJob): + """ + Ray-based materialization job that tracks the status of feature materialization. + """ + + def __init__( + self, + job_id: str, + status: MaterializationJobStatus, + result: Optional[DAGValue] = None, + error: Optional[BaseException] = None, + ): + super().__init__() + self._job_id = job_id + self._status = status + self._result = result + self._error = error + + def job_id(self) -> str: + return self._job_id + + def status(self) -> MaterializationJobStatus: + return self._status + + def error(self) -> Optional[BaseException]: + return self._error + + def should_be_retried(self) -> bool: + """Ray jobs are generally not retried by default.""" + return False + + def url(self) -> Optional[str]: + """Ray jobs don't have a specific URL.""" + return None + + def result(self) -> Optional[DAGValue]: + """Get the result of the materialization job.""" + return self._result diff --git a/sdk/python/feast/infra/compute_engines/ray/nodes.py b/sdk/python/feast/infra/compute_engines/ray/nodes.py new file mode 100644 index 00000000000..17c82fcc6d6 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/ray/nodes.py @@ -0,0 +1,660 @@ +import logging +from datetime import datetime, timedelta +from typing import List, Optional, Union + +import pandas as pd +import ray +from ray.data import Dataset + +from feast import BatchFeatureView, StreamFeatureView +from feast.aggregation import Aggregation +from feast.data_source import DataSource +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.utils import create_offline_store_retrieval_job +from feast.infra.ray_shared_utils import ( + apply_field_mapping, + broadcast_join, + distributed_windowed_join, +) + +logger = logging.getLogger(__name__) + +# Entity timestamp alias for historical feature retrieval +ENTITY_TS_ALIAS = "__entity_event_timestamp" + + +class RayReadNode(DAGNode): + """ + Ray node for reading data from offline stores. + """ + + def __init__( + self, + name: str, + source: DataSource, + column_info, + config: RayComputeEngineConfig, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ): + super().__init__(name) + self.source = source + self.column_info = column_info + self.config = config + self.start_time = start_time + self.end_time = end_time + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the read operation to load data from the offline store.""" + try: + # Use utility function to create retrieval job + retrieval_job = create_offline_store_retrieval_job( + data_source=self.source, + column_info=self.column_info, + context=context, + start_time=self.start_time, + end_time=self.end_time, + ) + + # Convert to Ray Dataset + if hasattr(retrieval_job, "to_ray_dataset"): + # If the retrieval job supports Ray datasets directly + ray_dataset = retrieval_job.to_ray_dataset() + else: + # Fall back to converting from Arrow/Pandas + try: + arrow_table = retrieval_job.to_arrow() + ray_dataset = ray.data.from_arrow(arrow_table) + except Exception: + # Ultimate fallback to pandas + df = retrieval_job.to_df() + ray_dataset = ray.data.from_pandas(df) + + # Apply field mapping if needed + field_mapping = getattr(self.source, "field_mapping", None) + if field_mapping: + ray_dataset = apply_field_mapping(ray_dataset, field_mapping) + + return DAGValue( + data=ray_dataset, + format=DAGFormat.RAY, + metadata={ + "source": "offline_store", + "source_type": type(self.source).__name__, + "start_time": self.start_time, + "end_time": self.end_time, + }, + ) + + except Exception as e: + logger.error(f"Ray read node failed: {e}") + raise + + +class RayJoinNode(DAGNode): + """ + Ray node for joining entity dataframes with feature data. + """ + + def __init__( + self, + name: str, + column_info, + config: RayComputeEngineConfig, + is_historical_retrieval: bool = False, + ): + super().__init__(name) + self.column_info = column_info + self.config = config + self.is_historical_retrieval = is_historical_retrieval + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the join operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + feature_dataset: Dataset = input_value.data + + # If this is not a historical retrieval, just return the feature data + if not self.is_historical_retrieval or context.entity_df is None: + return DAGValue( + data=feature_dataset, + format=DAGFormat.RAY, + metadata={"joined": False}, + ) + + # Convert entity_df to Ray Dataset + entity_df = context.entity_df + if isinstance(entity_df, pd.DataFrame): + entity_dataset = ray.data.from_pandas(entity_df) + else: + entity_dataset = entity_df + + # Perform the join using Ray operations + join_keys = self.column_info.join_keys + timestamp_col = self.column_info.timestamp_column + requested_feats = getattr(self.column_info, "feature_cols", []) + + # Check if the feature dataset contains aggregated features (from aggregation node) + # If so, we don't need point-in-time join logic - just simple join on entity keys + sample_data = feature_dataset.take(1) + is_aggregated = False + if sample_data: + if hasattr(sample_data[0], "columns"): + feature_cols = sample_data[0].columns.tolist() + else: + # Handle other data formats + feature_cols = ( + list(sample_data[0].keys()) + if isinstance(sample_data[0], dict) + else [] + ) + + # Check for aggregated feature column patterns + is_aggregated = any( + col.startswith( + ("sum_", "avg_", "mean_", "count_", "min_", "max_", "std_", "var_") + ) + for col in feature_cols + ) + + feature_size = feature_dataset.size_bytes() + + if is_aggregated: + # For aggregated features, do simple join on entity keys + feature_df = feature_dataset.to_pandas() + feature_ref = ray.put(feature_df) + + def join_with_aggregated_features(batch: pd.DataFrame) -> pd.DataFrame: + if batch.empty: + return batch + features = ray.get(feature_ref) + if join_keys: + result = pd.merge( + batch, + features, + on=join_keys, + how="left", + suffixes=("", "_feature"), + ) + else: + result = batch.copy() + return result + + joined_dataset = entity_dataset.map_batches( + join_with_aggregated_features, batch_format="pandas" + ) + else: + if feature_size <= self.config.broadcast_join_threshold_mb * 1024 * 1024: + # Use broadcast join for small feature datasets + joined_dataset = broadcast_join( + entity_dataset, + feature_dataset.to_pandas(), + join_keys, + timestamp_col, + requested_feats, + ) + else: + # Use distributed join for large datasets + joined_dataset = distributed_windowed_join( + entity_dataset, + feature_dataset, + join_keys, + timestamp_col, + requested_feats, + ) + + return DAGValue( + data=joined_dataset, + format=DAGFormat.RAY, + metadata={ + "joined": True, + "join_keys": join_keys, + "join_strategy": "broadcast" + if feature_size <= self.config.broadcast_join_threshold_mb * 1024 * 1024 + else "distributed", + }, + ) + + +class RayFilterNode(DAGNode): + """ + Ray node for filtering data based on TTL and custom conditions. + """ + + def __init__( + self, + name: str, + column_info, + config: RayComputeEngineConfig, + ttl: Optional[timedelta] = None, + filter_condition: Optional[str] = None, + ): + super().__init__(name) + self.column_info = column_info + self.config = config + self.ttl = ttl + self.filter_condition = filter_condition + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the filter operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + def apply_filters(batch: pd.DataFrame) -> pd.DataFrame: + """Apply TTL and custom filters to the batch.""" + if batch.empty: + return batch + + filtered_batch = batch.copy() + + # Apply TTL filter if specified + if self.ttl: + timestamp_col = self.column_info.timestamp_column + if timestamp_col in filtered_batch.columns: + # Import necessary modules at the top of the function + from datetime import timezone + + import pandas as pd + + # Convert to datetime if not already + if not pd.api.types.is_datetime64_any_dtype( + filtered_batch[timestamp_col] + ): + filtered_batch[timestamp_col] = pd.to_datetime( + filtered_batch[timestamp_col] + ) + + # For historical retrieval, use entity timestamp for TTL calculation + if ENTITY_TS_ALIAS in filtered_batch.columns: + # Use entity timestamp for TTL calculation (historical retrieval) + if not pd.api.types.is_datetime64_any_dtype( + filtered_batch[ENTITY_TS_ALIAS] + ): + filtered_batch[ENTITY_TS_ALIAS] = pd.to_datetime( + filtered_batch[ENTITY_TS_ALIAS] + ) + + # Apply TTL filter with both upper and lower bounds: + # 1. feature.ts <= entity.event_timestamp (upper bound) + # 2. feature.ts >= entity.event_timestamp - ttl (lower bound) + upper_bound = filtered_batch[ENTITY_TS_ALIAS] + lower_bound = filtered_batch[ENTITY_TS_ALIAS] - self.ttl + + filtered_batch = filtered_batch[ + (filtered_batch[timestamp_col] <= upper_bound) + & (filtered_batch[timestamp_col] >= lower_bound) + ] + else: + # Use current time for TTL calculation (real-time retrieval) + # Check if timestamp column is timezone-aware + if pd.api.types.is_datetime64tz_dtype( + filtered_batch[timestamp_col] + ): + # Use timezone-aware current time + current_time = datetime.now(timezone.utc) + else: + # Use naive datetime + current_time = datetime.now() + + ttl_threshold = current_time - self.ttl + + # Apply TTL filter + filtered_batch = filtered_batch[ + filtered_batch[timestamp_col] >= ttl_threshold + ] + + # Apply custom filter condition if specified + if self.filter_condition: + try: + filtered_batch = filtered_batch.query(self.filter_condition) + except Exception as e: + logger.warning(f"Custom filter failed: {e}") + + return filtered_batch + + filtered_dataset = dataset.map_batches(apply_filters, batch_format="pandas") + + return DAGValue( + data=filtered_dataset, + format=DAGFormat.RAY, + metadata={ + "filtered": True, + "ttl": self.ttl, + "filter_condition": self.filter_condition, + }, + ) + + +class RayAggregationNode(DAGNode): + """ + Ray node for performing aggregations on feature data. + """ + + def __init__( + self, + name: str, + aggregations: List[Aggregation], + group_by_keys: List[str], + timestamp_col: str, + config: RayComputeEngineConfig, + ): + super().__init__(name) + self.aggregations = aggregations + self.group_by_keys = group_by_keys + self.timestamp_col = timestamp_col + self.config = config + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the aggregation operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + # Convert aggregations to Ray's groupby format + agg_dict = {} + for agg in self.aggregations: + feature_name = f"{agg.function}_{agg.column}" + if agg.time_window: + feature_name += f"_{int(agg.time_window.total_seconds())}s" + + if agg.function == "count": + agg_dict[feature_name] = (agg.column, "count") + elif agg.function == "sum": + agg_dict[feature_name] = (agg.column, "sum") + elif agg.function == "mean" or agg.function == "avg": + agg_dict[feature_name] = (agg.column, "mean") + elif agg.function == "min": + agg_dict[feature_name] = (agg.column, "min") + elif agg.function == "max": + agg_dict[feature_name] = (agg.column, "max") + elif agg.function == "std": + agg_dict[feature_name] = (agg.column, "std") + elif agg.function == "var": + agg_dict[feature_name] = (agg.column, "var") + else: + logger.warning(f"Unknown aggregation function: {agg.function}") + continue + + # Apply aggregations using pandas fallback (Ray's native groupby has compatibility issues) + if self.group_by_keys and agg_dict: + # Use pandas-based aggregation for entire dataset + aggregated_dataset = self._fallback_pandas_aggregation(dataset, agg_dict) + else: + # No group keys or aggregations, return original dataset + aggregated_dataset = dataset + + return DAGValue( + data=aggregated_dataset, + format=DAGFormat.RAY, + metadata={ + "aggregated": True, + "aggregations": len(self.aggregations), + "group_by_keys": self.group_by_keys, + }, + ) + + def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Dataset: + """Fallback to pandas-based aggregation for the entire dataset.""" + # Convert entire dataset to pandas for aggregation + df = dataset.to_pandas() + + if df.empty: + return dataset + + # Group by the specified keys + if self.group_by_keys: + grouped = df.groupby(self.group_by_keys) + else: + # If no group keys, apply aggregations to entire dataset + grouped = df.groupby(lambda x: 0) # Dummy grouping + + # Apply each aggregation + agg_results = [] + for feature_name, (column, function) in agg_dict.items(): + if column in df.columns: + if function == "count": + result = grouped[column].count() + elif function == "sum": + result = grouped[column].sum() + elif function == "mean": + result = grouped[column].mean() + elif function == "min": + result = grouped[column].min() + elif function == "max": + result = grouped[column].max() + elif function == "std": + result = grouped[column].std() + elif function == "var": + result = grouped[column].var() + else: + logger.warning(f"Unknown aggregation function: {function}") + continue + + result.name = feature_name + agg_results.append(result) + + # Combine aggregation results + if agg_results: + result_df = pd.concat(agg_results, axis=1) + + # Reset index to make group keys regular columns + if self.group_by_keys: + result_df = result_df.reset_index() + + # Convert back to Ray Dataset + return ray.data.from_pandas(result_df) + else: + return dataset + + +class RayDedupNode(DAGNode): + """ + Ray node for deduplicating records. + """ + + def __init__( + self, + name: str, + column_info, + config: RayComputeEngineConfig, + ): + super().__init__(name) + self.column_info = column_info + self.config = config + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the deduplication operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame: + """Remove duplicates from the batch.""" + if batch.empty: + return batch + + # Get deduplication keys + join_keys = self.column_info.join_keys + timestamp_col = self.column_info.timestamp_column + + if join_keys: + # Sort by join keys and timestamp (most recent first) + sort_columns = join_keys + [timestamp_col] + available_columns = [ + col for col in sort_columns if col in batch.columns + ] + + if available_columns: + # Sort and deduplicate + sorted_batch = batch.sort_values( + available_columns, + ascending=[True] * len(join_keys) + + [False], # Recent timestamps first + ) + + # Keep first occurrence (most recent) for each join key combination + deduped_batch = sorted_batch.drop_duplicates( + subset=join_keys, + keep="first", + ) + + return deduped_batch + + return batch + + deduped_dataset = dataset.map_batches(deduplicate_batch, batch_format="pandas") + + return DAGValue( + data=deduped_dataset, + format=DAGFormat.RAY, + metadata={"deduped": True}, + ) + + +class RayTransformationNode(DAGNode): + """ + Ray node for applying feature transformations. + """ + + def __init__( + self, + name: str, + transformation, + config: RayComputeEngineConfig, + ): + super().__init__(name) + # Extract the UDF function to avoid serialization issues with PandasTransformation + if hasattr(transformation, "udf") and callable(transformation.udf): + self.transformation_udf = transformation.udf + self.transformation_name = getattr(transformation, "name", "unknown") + elif callable(transformation): + # Handle direct UDF functions + self.transformation_udf = transformation + self.transformation_name = getattr(transformation, "__name__", "unknown") + else: + self.transformation_udf = None + self.transformation_name = "unknown" + self.config = config + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the transformation operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + # Use the extracted UDF function directly + transformation_func = self.transformation_udf + + def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame: + """Apply the transformation to the batch.""" + if batch.empty: + return batch + + try: + # Apply the transformation function directly + if transformation_func and callable(transformation_func): + transformed_batch = transformation_func(batch) + else: + logger.warning( + "Transformation function not available, returning original batch" + ) + transformed_batch = batch + + return transformed_batch + except Exception as e: + logger.error(f"Transformation failed: {e}") + return batch + + transformed_dataset = dataset.map_batches( + apply_transformation, batch_format="pandas" + ) + + return DAGValue( + data=transformed_dataset, + format=DAGFormat.RAY, + metadata={ + "transformed": True, + "transformation": self.transformation_name, + }, + ) + + +class RayWriteNode(DAGNode): + """ + Ray node for writing results to online/offline stores. + """ + + def __init__( + self, + name: str, + feature_view: Union[BatchFeatureView, StreamFeatureView], + config: RayComputeEngineConfig, + ): + super().__init__(name) + self.feature_view = feature_view + self.config = config + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the write operation.""" + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.RAY) + dataset: Dataset = input_value.data + + def write_batch(batch: pd.DataFrame) -> pd.DataFrame: + """Write each batch to the appropriate stores.""" + if batch.empty: + return batch + + try: + # Convert to Arrow Table for writing + import pyarrow as pa + + arrow_table = pa.Table.from_pandas(batch) + + # Write to online store if enabled + if getattr(self.feature_view, "online", False): + # TODO: Implement proper online store writing with correct data format conversion + logger.debug( + f"Online store writing not implemented yet for {len(batch)} rows" + ) + + # Write to offline store if enabled + if getattr(self.feature_view, "offline", False): + try: + context.offline_store.offline_write_batch( + config=context.repo_config, + feature_view=self.feature_view, + table=arrow_table, + progress=lambda x: None, + ) + logger.debug(f"Wrote {len(batch)} rows to offline store") + except Exception as e: + logger.error(f"Failed to write to offline store: {e}") + + return batch + + except Exception as e: + logger.error(f"Write operation failed: {e}") + return batch + + # Apply write operation to all batches + written_dataset = dataset.map_batches(write_batch, batch_format="pandas") + + # Materialize the dataset to ensure writes are executed + written_dataset = written_dataset.materialize() + + return DAGValue( + data=written_dataset, + format=DAGFormat.RAY, + metadata={ + "written": True, + "feature_view": self.feature_view.name, + "online": getattr(self.feature_view, "online", False), + "offline": getattr(self.feature_view, "offline", False), + }, + ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index b41d6e45a8b..1e0ef944469 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import fsspec import numpy as np @@ -21,19 +21,27 @@ ) from feast.feature_logging import LoggingConfig, LoggingSource from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView -from feast.infra.offline_stores.file_source import FileSource, SavedDatasetFileStorage +from feast.infra.offline_stores.file_source import ( + FileLoggingDestination, + FileSource, + SavedDatasetFileStorage, +) from feast.infra.offline_stores.offline_store import ( OfflineStore, RetrievalJob, RetrievalMetadata, ) from feast.infra.offline_stores.offline_utils import ( - assert_expected_columns_in_entity_df, get_entity_df_timestamp_bounds, - get_expected_join_keys, get_pyarrow_schema_from_batch_source, infer_event_timestamp_from_entity_df, ) +from feast.infra.ray_shared_utils import ( + _build_required_columns, + apply_field_mapping, + ensure_timestamp_compatibility, + normalize_timestamp_columns, +) from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -57,11 +65,9 @@ def _get_data_schema_info( """ if isinstance(data, Dataset): schema = data.schema() - # Embed _create_dtypes_dict_from_schema logic inline dtypes = {} for i, col in enumerate(schema.names): field_type = schema.field(i).type - # Embed _pa_type_to_pandas_dtype logic inline try: pa_type_str = str(field_type).lower() feast_value_type = pa_to_feast_value_type(pa_type_str) @@ -98,119 +104,6 @@ def _apply_to_data( return process_func(data) -def _normalize_timestamp_columns( - data: Union[pd.DataFrame, Dataset], - columns: Union[str, List[str]], - inplace: bool = False, -) -> Union[pd.DataFrame, Dataset]: - """ - Normalize timestamp columns to UTC with second precision. - Works with both pandas DataFrames and Ray Datasets. - Args: - data: DataFrame or Ray Dataset containing the timestamp columns - columns: Column name (str) or list of column names (List[str]) to normalize - inplace: Whether to modify the DataFrame in place (only applies to pandas) - Returns: - DataFrame or Dataset with normalized timestamp columns - """ - # Normalize input to always be a list - column_list = [columns] if isinstance(columns, str) else columns - - def apply_normalization(series: pd.Series) -> pd.Series: - return ( - pd.to_datetime(series, utc=True, errors="coerce") - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - - if isinstance(data, Dataset): - - def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame: - for column in column_list: - if not batch.empty and column in batch.columns: - batch[column] = apply_normalization(batch[column]) - return batch - - return data.map_batches(normalize_batch, batch_format="pandas") - else: - if not inplace: - data = data.copy() - - for column in column_list: - if column in data.columns: - data[column] = apply_normalization(data[column]) - return data - - -def _ensure_timestamp_compatibility( - data: Union[pd.DataFrame, Dataset], - timestamp_fields: List[str], - inplace: bool = False, -) -> Union[pd.DataFrame, Dataset]: - """ - Ensure timestamp columns have compatible dtypes and precision for joins. - Works with both pandas DataFrames and Ray Datasets. - Args: - data: DataFrame or Ray Dataset to process - timestamp_fields: List of timestamp field names - inplace: Whether to modify the DataFrame in place (only applies to pandas) - Returns: - DataFrame or Dataset with compatible timestamp columns - """ - if isinstance(data, Dataset): - # Ray Dataset path - def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame: - # Use existing utility for timezone awareness - batch = make_df_tzaware(batch) - - # Then normalize timestamp precision for specified fields only - for field in timestamp_fields: - if field in batch.columns: - batch[field] = ( - pd.to_datetime(batch[field], utc=True, errors="coerce") - .dt.floor("s") - .astype("datetime64[ns, UTC]") - ) - return batch - - return data.map_batches(ensure_compatibility, batch_format="pandas") - else: - # Pandas DataFrame path - if not inplace: - data = data.copy() - - # Use existing utility for timezone awareness - data = make_df_tzaware(data) - - # Then normalize timestamp precision for specified fields only - for field in timestamp_fields: - if field in data.columns: - data = _normalize_timestamp_columns(data, field, inplace=True) - return data - - -def _build_required_columns( - join_key_columns: List[str], - feature_name_columns: List[str], - timestamp_columns: List[str], -) -> List[str]: - """ - Build list of required columns for data processing. - Args: - join_key_columns: List of join key columns - feature_name_columns: List of feature columns - timestamp_columns: List of timestamp columns - Returns: - List of all required columns - """ - all_required_columns = join_key_columns + feature_name_columns + timestamp_columns - if not join_key_columns: - all_required_columns.append(DUMMY_ENTITY_ID) - if "event_timestamp" not in all_required_columns: - all_required_columns.append("event_timestamp") - return all_required_columns - - def _handle_empty_dataframe_case( join_key_columns: List[str], feature_name_columns: List[str], @@ -271,37 +164,29 @@ def _safe_get_entity_timestamp_bounds( """ try: if isinstance(data, Dataset): - # Ray Dataset path - try Ray's built-in operations first min_ts = data.min(timestamp_column) max_ts = data.max(timestamp_column) else: - # Pandas DataFrame path if timestamp_column in data.columns: min_ts, max_ts = get_entity_df_timestamp_bounds(data, timestamp_column) else: return None, None - - # Convert to datetime if needed if hasattr(min_ts, "to_pydatetime"): min_ts = min_ts.to_pydatetime() elif isinstance(min_ts, pd.Timestamp): min_ts = min_ts.to_pydatetime() - if hasattr(max_ts, "to_pydatetime"): max_ts = max_ts.to_pydatetime() elif isinstance(max_ts, pd.Timestamp): max_ts = max_ts.to_pydatetime() - return min_ts, max_ts except Exception as e: logger.debug( f"Timestamp bounds extraction failed: {e}, falling back to manual calculation" ) - - # Fallback to manual calculation try: if isinstance(data, Dataset): - # Ray Dataset fallback + def extract_bounds(batch: pd.DataFrame) -> pd.DataFrame: if timestamp_column in batch.columns and not batch.empty: timestamps = pd.to_datetime(batch[timestamp_column], utc=True) @@ -320,7 +205,6 @@ def extract_bounds(batch: pd.DataFrame) -> pd.DataFrame: if pd.notna(min_ts) and pd.notna(max_ts): return min_ts.to_pydatetime(), max_ts.to_pydatetime() else: - # Pandas DataFrame fallback if timestamp_column in data.columns: timestamps = pd.to_datetime(data[timestamp_column], utc=True) return ( @@ -333,48 +217,6 @@ def extract_bounds(batch: pd.DataFrame) -> pd.DataFrame: return None, None -def _safe_validate_entity_dataframe( - data: Union[pd.DataFrame, Dataset], - feature_views: List[FeatureView], - project: str, - registry: BaseRegistry, -) -> None: - """ - Safely validate entity DataFrame or Dataset. - Works with both pandas DataFrames and Ray Datasets. - Args: - data: DataFrame or Ray Dataset to validate - feature_views: List of feature views to validate against - project: Feast project name - registry: Feature registry - """ - try: - # Get expected join keys for validation - expected_join_keys = get_expected_join_keys(project, feature_views, registry) - - dtypes, columns = _get_data_schema_info(data) - - # Infer event timestamp column - timestamp_col = infer_event_timestamp_from_entity_df(dtypes) - - # Validate DataFrame/Dataset has required columns - assert_expected_columns_in_entity_df(dtypes, expected_join_keys, timestamp_col) - - data_type = "Dataset" if isinstance(data, Dataset) else "DataFrame" - logger.info( - f"Entity {data_type} validation passed:\n" - f" Expected join keys: {expected_join_keys}\n" - f" Detected timestamp column: {timestamp_col}\n" - f" Available columns: {columns}" - ) - - except Exception as e: - # Log validation issues but don't fail - data_type = "Dataset" if isinstance(data, Dataset) else "DataFrame" - logger.warning(f"Entity {data_type} validation skipped due to error: {e}") - logger.debug("Validation error details:", exc_info=True) - - def _safe_validate_schema( config: RepoConfig, data_source: DataSource, @@ -395,15 +237,12 @@ def _safe_validate_schema( expected_schema, expected_columns = get_pyarrow_schema_from_batch_source( config, data_source ) - if set(expected_columns) != set(table_columns): logger.warning( f"Schema mismatch in {operation_name}:\n" f" Expected columns: {expected_columns}\n" f" Actual columns: {table_columns}" ) - - # Check if it's just a column order issue if set(expected_columns) == set(table_columns): logger.info(f"Columns match but order differs for {operation_name}") return expected_schema, expected_columns @@ -416,7 +255,6 @@ def _safe_validate_schema( f"Schema validation skipped for {operation_name} due to error: {e}" ) logger.debug("Schema validation error details:", exc_info=True) - return None @@ -439,34 +277,24 @@ def convert_batch(batch: pd.DataFrame) -> pd.DataFrame: for fv in feature_views: for feature in fv.features: feat_name = feature.name - - # Check if this feature exists in the batch if feat_name not in batch.columns: continue - try: - # Get the Feast ValueType for this feature value_type = feature.dtype.to_value_type() - - # Handle array/list types if value_type.name.endswith("_LIST"): batch[feat_name] = _convert_array_column( batch[feat_name], value_type ) else: - # Handle scalar types using feast type mapping target_pandas_type = feast_value_type_to_pandas_type(value_type) batch[feat_name] = _convert_scalar_column( batch[feat_name], value_type, target_pandas_type ) - except Exception as e: logger.warning( f"Failed to convert feature {feat_name} to proper type: {e}" ) - # Keep original dtype if conversion fails continue - return batch return _apply_to_data(data, convert_batch) @@ -489,13 +317,11 @@ def _convert_scalar_column( elif value_type == ValueType.UNIX_TIMESTAMP: return pd.to_datetime(series, unit="s", errors="coerce") else: - # For other types, use pandas default conversion return series.astype(target_pandas_type) def _convert_array_column(series: pd.Series, value_type: ValueType) -> pd.Series: """Convert an array feature column to the appropriate type with proper empty array handling.""" - # Determine the base type for array elements base_type_map = { ValueType.INT32_LIST: np.int32, ValueType.INT64_LIST: np.int64, @@ -511,37 +337,16 @@ def _convert_array_column(series: pd.Series, value_type: ValueType) -> pd.Series def convert_array_item(item): if item is None or (isinstance(item, list) and len(item) == 0): - # Return properly typed empty array if target_dtype == object: return np.array([], dtype=object) else: return np.array([], dtype=target_dtype) else: - # Return the item as-is for non-empty arrays return item return series.apply(convert_array_item) -def _apply_field_mapping( - data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str] -) -> Union[pd.DataFrame, Dataset]: - """ - Apply field mapping to column names. - Works with both pandas DataFrames and Ray Datasets. - Args: - data: DataFrame or Ray Dataset to apply mapping to - field_mapping: Dictionary mapping old column names to new column names - Returns: - DataFrame or Dataset with renamed columns - """ - - def rename_columns(df: pd.DataFrame) -> pd.DataFrame: - return df.rename(columns=field_mapping) - - return _apply_to_data(data, rename_columns) - - class RayOfflineStoreConfig(FeastConfigBaseModel): """ Configuration for the Ray Offline Store. @@ -561,6 +366,12 @@ class RayOfflineStoreConfig(FeastConfigBaseModel): target_partition_size_mb: Optional[int] = 64 window_size_for_joins: Optional[str] = "1H" + # Logging settings + enable_ray_logging: Optional[bool] = False + + # Ray configuration for resource management (memory, CPU limits) + ray_conf: Optional[Dict[str, Any]] = None + class RayResourceManager: """ @@ -598,10 +409,17 @@ def configure_ray_context(self) -> None: ctx.max_parallelism = self.available_cpus * multiplier ctx.shuffle_strategy = "sort" # type: ignore ctx.enable_tensor_extension_casting = False - logger.info( - f"Configured Ray context: {self.available_cpus} CPUs, " - f"{self.available_memory // 1024**3}GB memory, {self.num_nodes} nodes" - ) + + if not getattr(self.config, "enable_ray_logging", False): + ctx.enable_progress_bars = False + if hasattr(ctx, "verbose_progress"): + ctx.verbose_progress = False + + if getattr(self.config, "enable_ray_logging", False): + logger.info( + f"Configured Ray context: {self.available_cpus} CPUs, " + f"{self.available_memory // 1024**3}GB memory, {self.num_nodes} nodes" + ) def estimate_optimal_partitions(self, dataset_size_bytes: int) -> int: """ @@ -712,7 +530,6 @@ def _manual_point_in_time_join( if is_list_feature: result[feat] = [[] for _ in range(len(result))] else: - # Check if the feature column is datetime if feat in features_df.columns and pd.api.types.is_datetime64_any_dtype( features_df[feat] ): @@ -739,19 +556,13 @@ def _manual_point_in_time_join( entity_matches &= pd.Series( [False] * len(features_df), index=features_df.index ) - if not entity_matches.any(): continue - matching_features = features_df[entity_matches] - - # Apply time filter if timestamp field exists entity_timestamp = entity_row[timestamp_field] if timestamp_field in matching_features.columns: time_matches = matching_features[timestamp_field] <= entity_timestamp matching_features = matching_features[time_matches] - - # Skip if no features match entity criteria or time criteria if matching_features.empty: continue @@ -800,37 +611,32 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: """Join a batch with broadcast feature data.""" features = ray.get(feature_ref) - logger.info( - f"Processing feature view {feature_view_name} with join keys {join_keys}" + enable_logging = getattr( + self.resource_manager.config, "enable_ray_logging", False ) + if enable_logging: + logger.info( + f"Processing feature view {feature_view_name} with join keys {join_keys}" + ) - # Determine feature join keys - # For entity mapping (join key mapping), original_join_keys contains the original feature view join keys - # join_keys contains the mapped entity join keys if original_join_keys: - # Entity mapping case: entity has join_keys, features have original_join_keys feature_join_keys = original_join_keys entity_join_keys = join_keys else: - # Normal case: both use the same join keys feature_join_keys = join_keys entity_join_keys = join_keys - # Select only required feature columns plus join keys and timestamp feature_cols = [timestamp_field] + feature_join_keys + requested_feats - # Only include columns that actually exist in the features DataFrame available_feature_cols = [ col for col in feature_cols if col in features.columns ] - # Ensure we have the minimum required columns if timestamp_field not in available_feature_cols: raise ValueError( f"Timestamp field '{timestamp_field}' not found in features columns: {list(features.columns)}" ) - # Check if required feature columns exist missing_feats = [ feat for feat in requested_feats if feat not in features.columns ] @@ -841,14 +647,12 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: features_filtered = features[available_feature_cols].copy() - # Ensure timestamp columns have compatible dtypes and precision - batch = _normalize_timestamp_columns(batch, timestamp_field, inplace=True) - features_filtered = _normalize_timestamp_columns( + batch = normalize_timestamp_columns(batch, timestamp_field, inplace=True) + features_filtered = normalize_timestamp_columns( features_filtered, timestamp_field, inplace=True ) if not entity_join_keys: - # Temporal join without entity keys batch_sorted = batch.sort_values(timestamp_field).reset_index(drop=True) features_sorted = features_filtered.sort_values( timestamp_field @@ -860,32 +664,20 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: direction="backward", ) else: - # Ensure entity join keys exist in batch for key in entity_join_keys: if key not in batch.columns: batch[key] = np.nan - - # Ensure feature join keys exist in features for key in feature_join_keys: if key not in features_filtered.columns: features_filtered[key] = np.nan - - # Drop rows with NaN values in join keys or timestamp batch_clean = batch.dropna( subset=entity_join_keys + [timestamp_field] ).copy() features_clean = features_filtered.dropna( subset=feature_join_keys + [timestamp_field] ).copy() - - # If no valid data remains, return empty result if batch_clean.empty or features_clean.empty: - return batch.head(0) # Return empty dataframe with same columns - - # Sort both DataFrames for merge_asof requirements - # merge_asof requires: left sorted by 'on' column, right sorted by ['by'] + ['on'] columns - - # For the left DataFrame (batch), sort by timestamp (on column) + return batch.head(0) if timestamp_field in batch_clean.columns: batch_sorted = batch_clean.sort_values( timestamp_field, ascending=True @@ -893,21 +685,13 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: else: batch_sorted = batch_clean.reset_index(drop=True) - # For the right DataFrame (features), sort by join keys (by columns) + timestamp (on column) right_sort_columns = [] - - # Add join keys to sort columns (these are the 'by' columns for merge_asof) for key in feature_join_keys: if key in features_clean.columns: right_sort_columns.append(key) - - # Add timestamp field to sort columns (this is the 'on' column for merge_asof) if timestamp_field in features_clean.columns: right_sort_columns.append(timestamp_field) - - # Sort the right DataFrame if right_sort_columns: - # Remove duplicates first, then sort features_clean = features_clean.drop_duplicates( subset=right_sort_columns, keep="last" ) @@ -917,35 +701,27 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: else: features_sorted = features_clean.reset_index(drop=True) - # Verify sorting for merge_asof if ( timestamp_field in features_sorted.columns and len(features_sorted) > 1 ): - # Check if timestamp is monotonic within each group if feature_join_keys: - # Group by join keys and check if timestamp is monotonic within each group grouped = features_sorted.groupby(feature_join_keys, sort=False) for name, group in grouped: if not group[timestamp_field].is_monotonic_increasing: - # If not monotonic, sort again more carefully features_sorted = features_sorted.sort_values( feature_join_keys + [timestamp_field], ascending=True, ).reset_index(drop=True) break else: - # No join keys, just check timestamp monotonicity if not features_sorted[timestamp_field].is_monotonic_increasing: features_sorted = features_sorted.sort_values( timestamp_field, ascending=True ).reset_index(drop=True) - # Attempt merge_asof with proper error handling try: - # Remove duplicates from both DataFrames before merge_asof if feature_join_keys: - # For batch DataFrame, remove duplicates based on join keys + timestamp batch_dedup_cols = [ k for k in entity_join_keys if k in batch_sorted.columns ] @@ -955,8 +731,6 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: batch_sorted = batch_sorted.drop_duplicates( subset=batch_dedup_cols, keep="last" ) - - # For features DataFrame, remove duplicates based on join keys + timestamp feature_dedup_cols = [ k for k in feature_join_keys if k in features_sorted.columns ] @@ -967,9 +741,7 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: subset=feature_dedup_cols, keep="last" ) - # Perform merge_asof if feature_join_keys: - # Handle join keys properly - if they are the same, just use one set if entity_join_keys == feature_join_keys: result = pd.merge_asof( batch_sorted, @@ -980,7 +752,6 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: suffixes=("", "_right"), ) else: - # Different join keys, use left_by and right_by parameters result = pd.merge_asof( batch_sorted, features_sorted, @@ -1000,10 +771,10 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: ) except Exception as e: - logger.warning( - f"merge_asof failed: {e}, implementing manual point-in-time join" - ) - # Fall back to manual join + if enable_logging: + logger.warning( + f"merge_asof didn't work: {e}, implementing manual point-in-time join" + ) result = self._manual_point_in_time_join( batch_clean, features_clean, @@ -1012,7 +783,6 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: timestamp_field, requested_feats, ) - # Handle feature renaming if full_feature_names is True if full_feature_names and feature_view_name: for feat in requested_feats: if feat in result.columns: @@ -1038,27 +808,18 @@ def windowed_temporal_join( ) -> Dataset: """Perform windowed temporal join for large datasets.""" - # Use configured window size if not provided window_size = window_size or ( self.resource_manager.config.window_size_for_joins or "1H" ) - - # Step 1: Optimize both datasets for joining entity_optimized = self.optimize_dataset_for_join(entity_ds, join_keys) feature_optimized = self.optimize_dataset_for_join(feature_ds, join_keys) - - # Step 2: Add time windows and data source markers entity_windowed = self._add_time_windows_and_source_marker( entity_optimized, timestamp_field, "entity", window_size ) feature_windowed = self._add_time_windows_and_source_marker( feature_optimized, timestamp_field, "feature", window_size ) - - # Step 3: Union datasets for co-processing combined_ds = entity_windowed.union(feature_windowed) - - # Step 4: Group by time window and join keys, then apply point-in-time logic result_ds = combined_ds.map_batches( self._apply_windowed_point_in_time_logic, batch_format="pandas", @@ -1107,23 +868,17 @@ def _apply_windowed_point_in_time_logic( if len(batch) == 0: return pd.DataFrame() - # Group by window and join keys to apply merge_asof result_chunks = [] group_keys = ["time_window"] + join_keys for group_values, group_data in batch.groupby(group_keys): - # Separate entity and feature data entity_data = group_data[group_data["_data_source"] == "entity"].copy() feature_data = group_data[group_data["_data_source"] == "feature"].copy() - if len(entity_data) > 0 and len(feature_data) > 0: - # Drop helper columns for merge_asof entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) feature_clean = feature_data.drop( columns=["time_window", "_data_source"] ) - - # Apply merge_asof within the group if join_keys: merged = pd.merge_asof( entity_clean.sort_values(join_keys + [timestamp_field]), @@ -1142,7 +897,6 @@ def _apply_windowed_point_in_time_logic( result_chunks.append(merged) elif len(entity_data) > 0: - # No features found, return entity data with NaN features entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) for feat in requested_feats: if feat not in entity_clean.columns: @@ -1151,8 +905,6 @@ def _apply_windowed_point_in_time_logic( if result_chunks: result = pd.concat(result_chunks, ignore_index=True) - - # Handle feature renaming if full_feature_names is True if full_feature_names and feature_view_name: for feat in requested_feats: if feat in result.columns: @@ -1172,9 +924,11 @@ def __init__( Dataset, pd.DataFrame, Callable[[], Union[Dataset, pd.DataFrame]] ], staging_location: Optional[str] = None, + config: Optional[RayOfflineStoreConfig] = None, ): self._dataset_or_callable = dataset_or_callable self._staging_location = staging_location + self._config = config or RayOfflineStoreConfig() self._cached_df: Optional[pd.DataFrame] = None self._cached_dataset: Optional[Dataset] = None self._metadata: Optional[RetrievalMetadata] = None @@ -1182,12 +936,11 @@ def __init__( self._on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None self._feature_refs: List[str] = [] self._entity_df: Optional[pd.DataFrame] = None - self._prefer_ray_datasets: bool = True # New flag to prefer Ray datasets + self._prefer_ray_datasets: bool = True def _create_metadata(self) -> RetrievalMetadata: """Create metadata from the entity DataFrame and feature references.""" if self._entity_df is not None: - # Auto-detect timestamp column and get timestamp bounds using utilities timestamp_col = _safe_infer_event_timestamp_column( self._entity_df, "event_timestamp" ) @@ -1195,10 +948,8 @@ def _create_metadata(self) -> RetrievalMetadata: self._entity_df, timestamp_col ) - # Get keys (all columns except the detected timestamp column) keys = [col for col in self._entity_df.columns if col != timestamp_col] else: - # Try to extract metadata from Ray dataset if entity_df is not available try: result = self._resolve() if isinstance(result, Dataset): @@ -1258,21 +1009,14 @@ def to_df( validation_reference: Optional[ValidationReference] = None, timeout: Optional[int] = None, ) -> pd.DataFrame: - # Use cached DataFrame if available and no ODFVs if self._cached_df is not None and not self.on_demand_feature_views: df = self._cached_df else: - # If we have on-demand feature views, use the parent's implementation - # which calls to_arrow and applies the transformations if self.on_demand_feature_views: - logger.info( - f"Using parent implementation for {len(self.on_demand_feature_views)} ODFVs" - ) df = super().to_df( validation_reference=validation_reference, timeout=timeout ) else: - # For Ray datasets, prefer keeping data distributed until the final conversion if self._prefer_ray_datasets: ray_ds = self._get_ray_dataset() df = ray_ds.to_pandas() @@ -1284,13 +1028,10 @@ def to_df( df = result.to_pandas() self._cached_df = df - # Handle validation reference if provided if validation_reference: try: - # Import here to avoid circular imports from feast.dqm.errors import ValidationFailed - # Run validation using the validation reference validation_result = validation_reference.profile.validate(df) if not validation_result.is_success: raise ValidationFailed(validation_result) @@ -1306,36 +1047,29 @@ def to_arrow( validation_reference: Optional[ValidationReference] = None, timeout: Optional[int] = None, ) -> pa.Table: - # If we have ODFVs, use the parent's implementation if self.on_demand_feature_views: return super().to_arrow( validation_reference=validation_reference, timeout=timeout ) - # For Ray datasets, use direct Arrow conversion when available if self._prefer_ray_datasets: try: ray_ds = self._get_ray_dataset() - # Try to use Ray's native to_arrow() if available if hasattr(ray_ds, "to_arrow"): return ray_ds.to_arrow() else: - # Fallback to pandas conversion df = ray_ds.to_pandas() return pa.Table.from_pandas(df) except Exception: - # Fallback to pandas conversion df = self.to_df( validation_reference=validation_reference, timeout=timeout ) return pa.Table.from_pandas(df) else: - # Original implementation for non-Ray datasets result = self._resolve() if isinstance(result, pd.DataFrame): return pa.Table.from_pandas(result) else: - # For Ray Dataset, convert to pandas first then to arrow df = result.to_pandas() return pa.Table.from_pandas(df) @@ -1343,7 +1077,6 @@ def to_remote_storage(self) -> list[str]: if not self._staging_location: raise ValueError("Staging location must be set for remote materialization.") try: - # Use Ray dataset directly for remote storage ray_ds = self._get_ray_dataset() RayOfflineStore._ensure_ray_initialized() output_uri = os.path.join(self._staging_location, str(uuid.uuid4())) @@ -1371,7 +1104,6 @@ def to_sql(self) -> str: raise NotImplementedError("SQL export not supported for Ray offline store") def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: - # Use Ray dataset when possible if self._prefer_ray_datasets: ray_ds = self._get_ray_dataset() return ray_ds.to_pandas() @@ -1379,7 +1111,6 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: return self._resolve().to_pandas() def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: - # Use Ray dataset when possible if self._prefer_ray_datasets: ray_ds = self._get_ray_dataset() try: @@ -1396,7 +1127,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: if isinstance(result, pd.DataFrame): return pa.Table.from_pandas(result) else: - # For Ray Dataset, convert to pandas first then to arrow df = result.to_pandas() return pa.Table.from_pandas(df) @@ -1417,13 +1147,11 @@ def persist( if not allow_overwrite and os.path.exists(destination_path): raise SavedDatasetLocationAlreadyExists(location=destination_path) try: - # Use Ray dataset directly for persistence ray_ds = self._get_ray_dataset() if not destination_path.startswith(("s3://", "gs://", "hdfs://")): os.makedirs(os.path.dirname(destination_path), exist_ok=True) - # Use Ray's native write operations ray_ds.write_parquet(destination_path) return destination_path @@ -1436,38 +1164,18 @@ def materialize(self) -> None: ray_ds = self._get_ray_dataset() materialized_ds = ray_ds.materialize() self._cached_dataset = materialized_ds - logger.info("Ray dataset materialized successfully") + + if getattr(self._config, "enable_ray_logging", False): + logger.info("Ray dataset materialized successfully") except Exception as e: logger.warning(f"Failed to materialize Ray dataset: {e}") - def count(self) -> int: - """Get the number of rows in the dataset efficiently using Ray operations.""" - try: - ray_ds = self._get_ray_dataset() - return ray_ds.count() - except Exception: - # Fallback to pandas - df = self.to_df() - return len(df) - - def take(self, limit: int) -> pd.DataFrame: - """Take a limited number of rows efficiently using Ray operations.""" - try: - ray_ds = self._get_ray_dataset() - limited_ds = ray_ds.limit(limit) - return limited_ds.to_pandas() - except Exception: - # Fallback to pandas - df = self.to_df() - return df.head(limit) - def schema(self) -> pa.Schema: """Get the schema of the dataset efficiently using Ray operations.""" try: ray_ds = self._get_ray_dataset() return ray_ds.schema() except Exception: - # Fallback to pandas df = self.to_df() return pa.Table.from_pandas(df).schema @@ -1478,66 +1186,138 @@ def __init__(self): self._ray_initialized: bool = False self._resource_manager: Optional[RayResourceManager] = None self._data_processor: Optional[RayDataProcessor] = None - self._performance_monitoring: bool = True # Enable performance monitoring + + @staticmethod + def _suppress_ray_logging(): + """Suppress Ray and Ray Data logging completely.""" + import logging + import warnings + + # Suppress Ray warnings + warnings.filterwarnings("ignore", category=DeprecationWarning, module="ray") + warnings.filterwarnings("ignore", category=UserWarning, module="ray") + + # Set environment variables to suppress Ray output + os.environ["RAY_DISABLE_IMPORT_WARNING"] = "1" + os.environ["RAY_SUPPRESS_UNVERIFIED_TLS_WARNING"] = "1" + os.environ["RAY_LOG_LEVEL"] = "ERROR" + os.environ["RAY_DATA_LOG_LEVEL"] = "ERROR" + os.environ["RAY_DISABLE_PROGRESS_BARS"] = "1" + + # Suppress all Ray-related loggers + ray_loggers = [ + "ray", + "ray.data", + "ray.data.dataset", + "ray.data.context", + "ray.data._internal.streaming_executor", + "ray.data._internal.execution", + "ray.data._internal", + "ray.tune", + "ray.serve", + "ray.util", + "ray._private", + ] + for logger_name in ray_loggers: + logging.getLogger(logger_name).setLevel(logging.ERROR) + + # Configure DatasetContext to disable progress bars + try: + from ray.data.context import DatasetContext + + ctx = DatasetContext.get_current() + ctx.enable_progress_bars = False + if hasattr(ctx, "verbose_progress"): + ctx.verbose_progress = False + except Exception: + pass # Ignore if Ray Data is not available @staticmethod def _ensure_ray_initialized(config: Optional[RepoConfig] = None): """Ensure Ray is initialized with proper configuration.""" + ray_config = None + if config and hasattr(config, "offline_store"): + ray_config = config.offline_store + if isinstance(ray_config, RayOfflineStoreConfig): + if not ray_config.enable_ray_logging: + RayOfflineStore._suppress_ray_logging() + if not ray.is_initialized(): + ray_init_kwargs: Dict[str, Any] = { + "ignore_reinit_error": True, + "include_dashboard": False, + } + + if ( + ray_config + and isinstance(ray_config, RayOfflineStoreConfig) + and not ray_config.enable_ray_logging + ): + ray_init_kwargs.update( + { + "log_to_driver": False, + "logging_level": "ERROR", + } + ) + if config and hasattr(config, "offline_store"): - ray_config = config.offline_store if isinstance(ray_config, RayOfflineStoreConfig): if ray_config.use_ray_cluster and ray_config.ray_address: - ray.init( - address=ray_config.ray_address, - ignore_reinit_error=True, - include_dashboard=False, - ) + ray_init_kwargs["address"] = ray_config.ray_address else: - ray.init( - _node_ip_address=os.getenv("RAY_NODE_IP", "127.0.0.1"), - num_cpus=os.cpu_count() or 4, - ignore_reinit_error=True, - include_dashboard=False, + ray_init_kwargs.update( + { + "_node_ip_address": os.getenv( + "RAY_NODE_IP", "127.0.0.1" + ), + "num_cpus": os.cpu_count() or 4, + } ) + + if ray_config.ray_conf: + ray_init_kwargs.update(ray_config.ray_conf) else: - ray.init(ignore_reinit_error=True) - else: - ray.init(ignore_reinit_error=True) + pass # Use default initialization + + ray.init(**ray_init_kwargs) ctx = DatasetContext.get_current() ctx.shuffle_strategy = "sort" # type: ignore ctx.enable_tensor_extension_casting = False - # Log Ray cluster information + if ( + ray_config + and isinstance(ray_config, RayOfflineStoreConfig) + and not ray_config.enable_ray_logging + ): + RayOfflineStore._suppress_ray_logging() + if ray.is_initialized(): cluster_resources = ray.cluster_resources() - logger.info( - f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " - f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" - ) + if ( + not ray_config + or not isinstance(ray_config, RayOfflineStoreConfig) + or ray_config.enable_ray_logging + ): + logger.info( + f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " + f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" + ) def _init_ray(self, config: RepoConfig): ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) RayOfflineStore._ensure_ray_initialized(config) + + if not ray_config.enable_ray_logging: + RayOfflineStore._suppress_ray_logging() + if self._resource_manager is None: self._resource_manager = RayResourceManager(ray_config) self._resource_manager.configure_ray_context() if self._data_processor is None: self._data_processor = RayDataProcessor(self._resource_manager) - def _log_performance_metrics( - self, operation: str, dataset_size: int, duration: float - ): - """Log performance metrics for Ray operations.""" - if self._performance_monitoring: - throughput = dataset_size / duration if duration > 0 else 0 - logger.info( - f"Ray {operation} performance: {dataset_size} rows in {duration:.2f}s " - f"({throughput:.0f} rows/s)" - ) - def _get_source_path(self, source: DataSource, config: RepoConfig) -> str: if not isinstance(source, FileSource): raise ValueError("RayOfflineStore currently only supports FileSource") @@ -1556,58 +1336,20 @@ def _optimize_dataset_for_operation(self, ds: Dataset, operation: str) -> Datase ) if requirements["can_fit_in_memory"]: - # Materialize small datasets for better performance ds = ds.materialize() - # Optimize partitioning optimal_partitions = requirements["optimal_partitions"] current_partitions = ds.num_blocks() if current_partitions != optimal_partitions: - logger.debug( - f"Repartitioning dataset from {current_partitions} to {optimal_partitions} blocks" - ) + if getattr(self._resource_manager.config, "enable_ray_logging", False): + logger.debug( + f"Repartitioning dataset from {current_partitions} to {optimal_partitions} blocks" + ) ds = ds.repartition(num_blocks=optimal_partitions) return ds - def supports_remote_storage_export(self) -> bool: - """Check if remote storage export is supported.""" - return True # Ray supports remote storage natively - - def get_feature_server_endpoint(self) -> Optional[str]: - """Get feature server endpoint if available.""" - return None # Ray offline store doesn't have a feature server endpoint - - def get_infra_object_names(self) -> List[str]: - """Get infrastructure object names managed by this store.""" - return [] # Ray offline store doesn't manage persistent infrastructure objects - - def plan_infra(self, config: RepoConfig, desired_registry_proto: Any) -> Any: - """Plan infrastructure changes.""" - # Ray offline store doesn't require infrastructure planning - return None - - def update_infra( - self, - project: str, - tables_to_delete: List[Any], - tables_to_keep: List[Any], - entities_to_delete: List[Any], - entities_to_keep: List[Any], - partial: bool, - ) -> None: - """Update infrastructure.""" - # Ray offline store doesn't require infrastructure updates - pass - - def teardown_infra( - self, project: str, tables: List[Any], entities: List[Any] - ) -> None: - """Teardown infrastructure.""" - # Ray offline store doesn't require infrastructure teardown - pass - @staticmethod def offline_write_batch( config: RepoConfig, @@ -1625,90 +1367,69 @@ def offline_write_batch( repo_path = getattr(config, "repo_path", None) or os.getcwd() ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) + + if not ray_config.enable_ray_logging: + RayOfflineStore._suppress_ray_logging() assert isinstance(feature_view.batch_source, FileSource) - # Enhanced schema validation using safe utility validation_result = _safe_validate_schema( config, feature_view.batch_source, table.column_names, "offline_write_batch" ) if validation_result: expected_schema, expected_columns = validation_result - # Try to reorder columns to match expected order if needed if expected_columns != table.column_names and set(expected_columns) == set( table.column_names ): - logger.info("Reordering table columns to match expected schema") + if getattr(ray_config, "enable_ray_logging", False): + logger.info("Reordering table columns to match expected schema") table = table.select(expected_columns) batch_source_path = feature_view.batch_source.file_options.uri feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) - # Use Ray Dataset for efficient writing ds = ray.data.from_arrow(table) try: - # If the path points to a file, write directly to that file location - # If it points to a directory, write to that directory if feature_path.endswith(".parquet"): - # For single file writes, check if file exists and append if it does if os.path.exists(feature_path): - # Read existing data as Ray Dataset existing_ds = ray.data.read_parquet(feature_path) - # Append new data using Ray operations combined_ds = existing_ds.union(ds) - # Write combined data combined_ds.write_parquet(feature_path) else: - # Write new data ds.write_parquet(feature_path) else: - # Write to directory (multiple parquet files) os.makedirs(feature_path, exist_ok=True) ds.write_parquet(feature_path) - # Call progress callback if provided if progress: progress(table.num_rows) - except Exception as e: - logger.error(f"Failed to write batch data: {e}") - # Fallback to pandas-based writing - logger.info("Falling back to pandas-based writing") - - # Convert to pandas for fallback + except Exception: + if getattr(ray_config, "enable_ray_logging", False): + logger.info("Falling back to pandas-based writing") df = table.to_pandas() - if feature_path.endswith(".parquet"): - # Check if file exists and append if it does if os.path.exists(feature_path): - # Read existing data existing_df = pd.read_parquet(feature_path) - # Append new data combined_df = pd.concat([existing_df, df], ignore_index=True) - # Write combined data combined_df.to_parquet(feature_path, index=False) else: - # Write new data df.to_parquet(feature_path, index=False) else: - # Write to directory (multiple parquet files) os.makedirs(feature_path, exist_ok=True) - - # Convert to Ray dataset and write ds_fallback = ray.data.from_pandas(df) ds_fallback.write_parquet(feature_path) - # Call progress callback if provided if progress: progress(table.num_rows) - # Log performance metrics duration = time.time() - start_time - logger.info( - f"Ray offline_write_batch performance: {table.num_rows} rows in {duration:.2f}s " - f"({table.num_rows / duration:.0f} rows/s)" - ) + if getattr(ray_config, "enable_ray_logging", False): + logger.info( + f"Ray offline_write_batch performance: {table.num_rows} rows in {duration:.2f}s " + f"({table.num_rows / duration:.0f} rows/s)" + ) def online_write_batch( self, @@ -1719,84 +1440,35 @@ def online_write_batch( """Ray offline store doesn't support online writes.""" raise NotImplementedError("Ray offline store doesn't support online writes") - def get_table_query_string(self) -> str: - """Get table query string format.""" - return "file://{table_name}" - - def get_table_column_names_and_types( - self, config: RepoConfig, data_source: DataSource - ) -> Iterable[Tuple[str, str]]: - """Get table column names and types efficiently using Ray.""" - return self.get_table_column_names_and_types_from_data_source( - config, data_source + @staticmethod + def _process_filtered_batch( + batch: pd.DataFrame, + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_columns: List[str], + timestamp_field_mapped: str, + ) -> pd.DataFrame: + batch = make_df_tzaware(batch) + if batch.empty: + return _handle_empty_dataframe_case( + join_key_columns, feature_name_columns, timestamp_columns + ) + all_required_columns = _build_required_columns( + join_key_columns, feature_name_columns, timestamp_columns ) - - def create_ray_dataset_from_table( - self, config: RepoConfig, data_source: DataSource - ) -> Dataset: - """Create a Ray Dataset from a data source.""" - self._init_ray(config) - source_path = self._get_source_path(data_source, config) - ds = ray.data.read_parquet(source_path) - - # Apply field mapping if needed - field_mapping = getattr(data_source, "field_mapping", None) - if field_mapping: - ds = _apply_field_mapping(ds, field_mapping) - - return ds - - def get_dataset_statistics(self, ds: Dataset) -> Dict[str, Any]: - """Get comprehensive statistics for a Ray Dataset.""" - try: - stats = { - "num_rows": ds.count(), - "num_blocks": ds.num_blocks(), - "size_bytes": ds.size_bytes(), - "schema": ds.schema(), - } - - # Add column statistics if possible - try: - column_stats = {} - for col in ds.schema().names: - try: - column_stats[col] = { - "min": ds.min(col), - "max": ds.max(col), - "mean": ds.mean(col) - if ds.schema().field(col).type - in [pa.float32(), pa.float64(), pa.int32(), pa.int64()] - else None, - } - except Exception: - # Skip columns that don't support these operations - pass - stats["column_stats"] = column_stats - except Exception: - pass - - return stats - except Exception as e: - logger.warning(f"Failed to get dataset statistics: {e}") - return {"error": str(e)} - - def validate_data_source( - self, - config: RepoConfig, - data_source: DataSource, - ): - """Validates the underlying data source.""" - self._init_ray(config) - data_source.validate(config=config) - - def get_table_column_names_and_types_from_data_source( - self, - config: RepoConfig, - data_source: DataSource, - ) -> Iterable[Tuple[str, str]]: - """Returns the list of column names and raw column types for a DataSource.""" - return data_source.get_table_column_names_and_types(config=config) + if not join_key_columns: + batch[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + available_columns = [ + col for col in all_required_columns if col in batch.columns + ] + batch = batch[available_columns] + if ( + "event_timestamp" not in batch.columns + and timestamp_field_mapped != "event_timestamp" + ): + if timestamp_field_mapped in batch.columns: + batch["event_timestamp"] = batch[timestamp_field_mapped] + return batch @staticmethod def _load_and_filter_dataset( @@ -1809,38 +1481,14 @@ def _load_and_filter_dataset( start_date: Optional[datetime], end_date: Optional[datetime], ) -> pd.DataFrame: - """ - Common method to load and filter dataset for both pull_latest and pull_all methods. - Args: - source_path: Path to the data source - data_source: DataSource object containing field mapping - join_key_columns: List of join key columns - feature_name_columns: List of feature columns - timestamp_field: Name of the timestamp field - created_timestamp_column: Optional created timestamp column - start_date: Optional start date for filtering - end_date: Optional end date for filtering - Returns: - Processed pandas DataFrame - """ try: - # Get field mapping for column renaming after loading field_mapping = getattr(data_source, "field_mapping", None) - - # Load and filter the dataset using the original timestamp field name ds = RayOfflineStore._create_filtered_dataset( source_path, timestamp_field, start_date, end_date ) - - # Convert to pandas for processing df = ds.to_pandas() - df = make_df_tzaware(df) - - # Apply field mapping if needed if field_mapping: df = df.rename(columns=field_mapping) - - # Get mapped field names timestamp_field_mapped = ( field_mapping.get(timestamp_field, timestamp_field) if field_mapping @@ -1851,54 +1499,24 @@ def _load_and_filter_dataset( if field_mapping and created_timestamp_column else created_timestamp_column ) - - # Build timestamp columns list timestamp_columns = [timestamp_field_mapped] if created_timestamp_column_mapped: timestamp_columns.append(created_timestamp_column_mapped) - - # Normalize timestamp columns - df = _normalize_timestamp_columns(df, timestamp_columns, inplace=True) - - # Handle empty DataFrame case - if df.empty: - return _handle_empty_dataframe_case( - join_key_columns, feature_name_columns, timestamp_columns - ) - - # Build required columns list - all_required_columns = _build_required_columns( - join_key_columns, feature_name_columns, timestamp_columns + df = normalize_timestamp_columns(df, timestamp_columns, inplace=True) + df = RayOfflineStore._process_filtered_batch( + df, + join_key_columns, + feature_name_columns, + timestamp_columns, + timestamp_field_mapped, ) - if not join_key_columns: - df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL - - # Select only the required columns that exist - available_columns = [ - col for col in all_required_columns if col in df.columns - ] - df = df[available_columns] - - # Basic sorting by timestamp (most recent first) existing_timestamp_columns = [ col for col in timestamp_columns if col in df.columns ] if existing_timestamp_columns: df = df.sort_values(existing_timestamp_columns, ascending=False) - - # Reset index df = df.reset_index(drop=True) - - # Ensure 'event_timestamp' column exists for pandas backend compatibility - if ( - "event_timestamp" not in df.columns - and timestamp_field_mapped != "event_timestamp" - ): - if timestamp_field_mapped in df.columns: - df["event_timestamp"] = df[timestamp_field_mapped] - return df - except Exception as e: raise RuntimeError(f"Failed to load data from {source_path}: {e}") @@ -1913,34 +1531,13 @@ def _load_and_filter_dataset_ray( start_date: Optional[datetime], end_date: Optional[datetime], ) -> Dataset: - """ - Ray-native method to load and filter dataset for distributed processing. - Args: - source_path: Path to the data source - data_source: DataSource object containing field mapping - join_key_columns: List of join key columns - feature_name_columns: List of feature columns - timestamp_field: Name of the timestamp field - created_timestamp_column: Optional created timestamp column - start_date: Optional start date for filtering - end_date: Optional end date for filtering - Returns: - Processed Ray Dataset - """ try: - # Get field mapping for column renaming after loading field_mapping = getattr(data_source, "field_mapping", None) - - # Load and filter the dataset using the original timestamp field name ds = RayOfflineStore._create_filtered_dataset( source_path, timestamp_field, start_date, end_date ) - - # Apply field mapping if needed using Ray operations if field_mapping: - ds = _apply_field_mapping(ds, field_mapping) - - # Get mapped field names + ds = apply_field_mapping(ds, field_mapping) timestamp_field_mapped = ( field_mapping.get(timestamp_field, timestamp_field) if field_mapping @@ -1951,61 +1548,33 @@ def _load_and_filter_dataset_ray( if field_mapping and created_timestamp_column else created_timestamp_column ) - - # Build timestamp columns list timestamp_columns = [timestamp_field_mapped] if created_timestamp_column_mapped: timestamp_columns.append(created_timestamp_column_mapped) - - # Normalize timestamp columns using Ray operations - ds = _normalize_timestamp_columns(ds, timestamp_columns) - - # Process dataset using Ray operations - def process_batch(batch: pd.DataFrame) -> pd.DataFrame: - # Apply timezone awareness - batch = make_df_tzaware(batch) - - # Handle empty batch case - if batch.empty: - return _handle_empty_dataframe_case( - join_key_columns, feature_name_columns, timestamp_columns - ) - - # Build required columns list - all_required_columns = _build_required_columns( - join_key_columns, feature_name_columns, timestamp_columns - ) - if not join_key_columns: - batch[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL - - # Select only the required columns that exist - available_columns = [ - col for col in all_required_columns if col in batch.columns - ] - batch = batch[available_columns] - - # Ensure 'event_timestamp' column exists for pandas backend compatibility - if ( - "event_timestamp" not in batch.columns - and timestamp_field_mapped != "event_timestamp" - ): - if timestamp_field_mapped in batch.columns: - batch["event_timestamp"] = batch[timestamp_field_mapped] - - return batch - - ds = ds.map_batches(process_batch, batch_format="pandas") - - # Sort by timestamp (most recent first) using Ray operations + # Exclude __log_timestamp from normalization as it's used for time range filtering + exclude_columns = ( + ["__log_timestamp"] if "__log_timestamp" in timestamp_columns else [] + ) + ds = normalize_timestamp_columns( + ds, timestamp_columns, exclude_columns=exclude_columns + ) + ds = ds.map_batches( + lambda batch: RayOfflineStore._process_filtered_batch( + batch, + join_key_columns, + feature_name_columns, + timestamp_columns, + timestamp_field_mapped, + ), + batch_format="pandas", + ) timestamp_columns_existing = [ col for col in timestamp_columns if col in ds.schema().names ] if timestamp_columns_existing: - # Sort using Ray's native sorting ds = ds.sort(timestamp_columns_existing, descending=True) return ds - except Exception as e: raise RuntimeError(f"Failed to load data from {source_path}: {e}") @@ -2031,7 +1600,6 @@ def _pull_latest_processing_ray( if not join_key_columns: return ds - # Get mapped field names timestamp_field_mapped = ( field_mapping.get(timestamp_field, timestamp_field) if field_mapping @@ -2043,7 +1611,6 @@ def _pull_latest_processing_ray( else created_timestamp_column ) - # Build timestamp columns for sorting timestamp_columns = [timestamp_field_mapped] if created_timestamp_column_mapped: timestamp_columns.append(created_timestamp_column_mapped) @@ -2052,12 +1619,10 @@ def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame: if batch.empty: return batch - # Filter out timestamp columns that don't exist in the dataframe existing_timestamp_columns = [ col for col in timestamp_columns if col in batch.columns ] - # Sort by join keys (ascending) and timestamps (descending for latest first) sort_columns = join_key_columns + existing_timestamp_columns if sort_columns: batch = batch.sort_values( @@ -2088,7 +1653,6 @@ def pull_latest_from_table_or_query( source_path = store._get_source_path(data_source, config) def _load_ray_dataset(): - # Use Ray-native processing for better performance ds = store._load_and_filter_dataset_ray( source_path, data_source, @@ -2099,8 +1663,6 @@ def _load_ray_dataset(): start_date, end_date, ) - - # Apply pull_latest processing (deduplication) using Ray operations field_mapping = getattr(data_source, "field_mapping", None) ds = store._pull_latest_processing_ray( ds, @@ -2113,7 +1675,6 @@ def _load_ray_dataset(): return ds def _load_pandas_fallback(): - # Fallback to pandas processing for compatibility return store._load_and_filter_dataset( source_path, data_source, @@ -2125,16 +1686,18 @@ def _load_pandas_fallback(): end_date, ) - # Try Ray-native processing first, fallback to pandas if needed try: return RayRetrievalJob( - _load_ray_dataset, staging_location=config.offline_store.storage_path + _load_ray_dataset, + staging_location=config.offline_store.storage_path, + config=config.offline_store, ) except Exception as e: logger.warning(f"Ray-native processing failed: {e}, falling back to pandas") return RayRetrievalJob( _load_pandas_fallback, staging_location=config.offline_store.storage_path, + config=config.offline_store, ) @staticmethod @@ -2158,7 +1721,6 @@ def pull_all_from_table_or_query( raise FileNotFoundError(f"Parquet path does not exist: {source_path}") def _load_ray_dataset(): - # Use Ray-native processing for better performance return store._load_and_filter_dataset_ray( source_path, data_source, @@ -2171,7 +1733,6 @@ def _load_ray_dataset(): ) def _load_pandas_fallback(): - # Fallback to pandas processing for compatibility return store._load_and_filter_dataset( source_path, data_source, @@ -2183,16 +1744,18 @@ def _load_pandas_fallback(): end_date, ) - # Try Ray-native processing first, fallback to pandas if needed try: return RayRetrievalJob( - _load_ray_dataset, staging_location=config.offline_store.storage_path + _load_ray_dataset, + staging_location=config.offline_store.storage_path, + config=config.offline_store, ) except Exception as e: logger.warning(f"Ray-native processing failed: {e}, falling back to pandas") return RayRetrievalJob( _load_pandas_fallback, staging_location=config.offline_store.storage_path, + config=config.offline_store, ) @staticmethod @@ -2205,31 +1768,52 @@ def write_logged_features( ) -> None: RayOfflineStore._ensure_ray_initialized(config) - repo_path = getattr(config, "repo_path", None) or os.getcwd() - - # Get source path and resolve URI - source_path = getattr(source, "file_path", None) - if not source_path: - raise ValueError("LoggingSource must have a file_path attribute") + ray_config = getattr(config, "offline_store", None) + if ( + ray_config + and isinstance(ray_config, RayOfflineStoreConfig) + and not ray_config.enable_ray_logging + ): + RayOfflineStore._suppress_ray_logging() + + destination = logging_config.destination + assert isinstance(destination, FileLoggingDestination), ( + f"Ray offline store only supports FileLoggingDestination for logging, " + f"got {type(destination)}" + ) - path = FileSource.get_uri_for_file_path(repo_path, source_path) + repo_path = getattr(config, "repo_path", None) or os.getcwd() + absolute_path = FileSource.get_uri_for_file_path(repo_path, destination.path) try: - # Use Ray dataset for efficient writing if isinstance(data, Path): ds = ray.data.read_parquet(str(data)) else: - # Convert PyArrow Table to Ray Dataset directly ds = ray.data.from_arrow(data) - # Materialize for better performance - ds = ds.materialize() - - if not path.startswith(("s3://", "gs://")): - os.makedirs(os.path.dirname(path), exist_ok=True) + # Normalize feature timestamp precision to seconds to match test expectations during write + # Note: Don't normalize __log_timestamp as it's used for time range filtering + def normalize_timestamps(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + for col in batch.columns: + if ( + pd.api.types.is_datetime64_any_dtype(batch[col]) + and col != "__log_timestamp" + ): + batch[col] = batch[col].dt.floor("s") + return batch - # Use Ray's native write operations - ds.write_parquet(path) + ds = ds.map_batches(normalize_timestamps, batch_format="pandas") + ds = ds.materialize() + filesystem, resolved_path = FileSource.create_filesystem_and_path( + absolute_path, destination.s3_endpoint_override + ) + path_obj = Path(resolved_path) + if path_obj.suffix == ".parquet": + path_obj = path_obj.with_suffix("") + if not absolute_path.startswith(("s3://", "gs://")): + path_obj.mkdir(parents=True, exist_ok=True) + ds.write_parquet(str(path_obj)) except Exception as e: raise RuntimeError(f"Failed to write logged features: {e}") @@ -2331,21 +1915,15 @@ def get_historical_features( # Load entity_df as Ray dataset for distributed processing if isinstance(entity_df, str): entity_ds = ray.data.read_csv(entity_df) - # Keep a minimal pandas copy only for metadata creation entity_df_sample = entity_ds.limit(1000).to_pandas() else: entity_ds = ray.data.from_pandas(entity_df) entity_df_sample = entity_df.copy() - # Make entity dataset timezone aware and normalize timestamp using Ray operations - entity_ds = _ensure_timestamp_compatibility(entity_ds, ["event_timestamp"]) - - # Parse feature_refs and get ODFVs + entity_ds = ensure_timestamp_compatibility(entity_ds, ["event_timestamp"]) on_demand_feature_views = OnDemandFeatureView.get_requested_odfvs( feature_refs, project, registry ) - - # Validate request data for ODFVs using sample for odfv in on_demand_feature_views: odfv_request_data_schema = odfv.get_request_data_schema() for feature_name in odfv_request_data_schema.keys(): @@ -2355,19 +1933,10 @@ def get_historical_features( feature_view_name=odfv.name, ) - # Filter out on-demand feature views from regular feature views - # ODFVs don't have data sources and are computed from base features odfv_names = {odfv.name for odfv in on_demand_feature_views} regular_feature_views = [ fv for fv in feature_views if fv.name not in odfv_names ] - - # Enhanced validation using unified operations - _safe_validate_entity_dataframe( - entity_ds, regular_feature_views, project, registry - ) - - # Apply field mappings to entity dataset if needed using unified operations global_field_mappings = {} for fv in regular_feature_views: mapping = getattr(fv.batch_source, "field_mapping", None) @@ -2382,12 +1951,9 @@ def get_historical_features( if v in entity_df_sample.columns } if cols_to_rename: - entity_ds = _apply_field_mapping(entity_ds, cols_to_rename) + entity_ds = apply_field_mapping(entity_ds, cols_to_rename) - # Start with entity dataset - keep it as Ray dataset throughout result_ds = entity_ds - - # Process each regular feature view with intelligent join strategy for fv in regular_feature_views: fv_feature_refs = [ ref @@ -2397,14 +1963,12 @@ def get_historical_features( if not fv_feature_refs: continue - # Get join configuration entities = fv.entities or [] entity_objs = [registry.get_entity(e, project) for e in entities] original_join_keys, _, timestamp_field, created_col = _get_column_names( fv, entity_objs ) - # Apply join key mapping from projection if present if fv.projection.join_key_map: join_keys = [ fv.projection.join_key_map.get(key, key) @@ -2413,10 +1977,8 @@ def get_historical_features( else: join_keys = original_join_keys - # Extract requested features requested_feats = [ref.split(":", 1)[1] for ref in fv_feature_refs] - # Validate requested features exist available_feature_names = [f.name for f in fv.features] missing_feats = [ f for f in requested_feats if f not in available_feature_names @@ -2427,22 +1989,18 @@ def get_historical_features( f"(available: {available_feature_names})" ) - # Load feature data as Ray dataset source_path = store._get_source_path(fv.batch_source, config) feature_ds = ray.data.read_parquet(source_path) feature_size = feature_ds.size_bytes() - # Apply field mapping to feature dataset if needed using unified operations field_mapping = getattr(fv.batch_source, "field_mapping", None) if field_mapping: - feature_ds = _apply_field_mapping(feature_ds, field_mapping) - # Update join keys and timestamp field to mapped names + feature_ds = apply_field_mapping(feature_ds, field_mapping) join_keys = [field_mapping.get(k, k) for k in join_keys] timestamp_field = field_mapping.get(timestamp_field, timestamp_field) if created_col: created_col = field_mapping.get(created_col, created_col) - # Ensure timestamp compatibility in entity dataset using unified operations if ( timestamp_field != "event_timestamp" and timestamp_field not in entity_df_sample.columns @@ -2457,9 +2015,8 @@ def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: result_ds = result_ds.map_batches( add_timestamp_field, batch_format="pandas" ) - result_ds = _normalize_timestamp_columns(result_ds, timestamp_field) + result_ds = normalize_timestamp_columns(result_ds, timestamp_field) - # Determine join strategy based on dataset sizes and cluster resources if store._resource_manager is None: raise ValueError("Resource manager not initialized") requirements = store._resource_manager.estimate_processing_requirements( @@ -2468,12 +2025,12 @@ def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: if requirements["should_broadcast"]: # Use broadcast join for small feature datasets - logger.info( - f"Using broadcast join for {fv.name} (size: {feature_size // 1024**2}MB)" - ) - # Convert to pandas only for broadcast join + if getattr(store._resource_manager.config, "enable_ray_logging", False): + logger.info( + f"Using broadcast join for {fv.name} (size: {feature_size // 1024**2}MB)" + ) feature_df = feature_ds.to_pandas() - feature_df = _ensure_timestamp_compatibility( + feature_df = ensure_timestamp_compatibility( feature_df, [timestamp_field] ) @@ -2491,12 +2048,11 @@ def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: ) else: # Use distributed windowed join for large feature datasets - logger.info( - f"Using distributed join for {fv.name} (size: {feature_size // 1024**2}MB)" - ) - - # Ensure timestamp format in feature dataset using unified operations - feature_ds = _ensure_timestamp_compatibility( + if getattr(store._resource_manager.config, "enable_ray_logging", False): + logger.info( + f"Using distributed join for {fv.name} (size: {feature_size // 1024**2}MB)" + ) + feature_ds = ensure_timestamp_compatibility( feature_ds, [timestamp_field] ) @@ -2516,34 +2072,27 @@ def add_timestamp_field(batch: pd.DataFrame) -> pd.DataFrame: else None, ) - # Final processing: clean up and ensure proper column structure using Ray operations def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: batch = batch.copy() - # Preserve existing feature columns (including renamed ones) existing_columns = set(batch.columns) - - # Re-attach any missing original entity columns that aren't already present for col in entity_df_sample.columns: if col not in existing_columns: - # For missing columns, use values from entity df sample if len(batch) <= len(entity_df_sample): batch[col] = entity_df_sample[col].iloc[: len(batch)].values else: - # Repeat values if batch is larger repeated_values = np.tile( entity_df_sample[col].values, (len(batch) // len(entity_df_sample) + 1), ) batch[col] = repeated_values[: len(batch)] - # Ensure event_timestamp is present if "event_timestamp" not in batch.columns: if "event_timestamp" in entity_df_sample.columns: batch["event_timestamp"] = ( entity_df_sample["event_timestamp"].iloc[: len(batch)].values ) - batch = _normalize_timestamp_columns( + batch = normalize_timestamp_columns( batch, "event_timestamp", inplace=True ) elif timestamp_field in batch.columns: @@ -2552,20 +2101,18 @@ def finalize_result(batch: pd.DataFrame) -> pd.DataFrame: return batch result_ds = result_ds.map_batches(finalize_result, batch_format="pandas") - - # Apply feature type conversion using unified operations result_ds = _convert_feature_column_types(result_ds, regular_feature_views) - # Storage path validation storage_path = config.offline_store.storage_path if not storage_path: raise ValueError("Storage path must be set in config") - # Create retrieval job following standard pattern - job = RayRetrievalJob(result_ds, staging_location=storage_path) + job = RayRetrievalJob( + result_ds, staging_location=storage_path, config=config.offline_store + ) job._full_feature_names = full_feature_names job._on_demand_feature_views = on_demand_feature_views job._feature_refs = feature_refs - job._entity_df = entity_df_sample # Use sample for metadata creation + job._entity_df = entity_df_sample job._metadata = job._create_metadata() return job diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py index 43628e7ea1a..32c0d6dbabd 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py @@ -31,6 +31,15 @@ def __init__(self, project_name: str, *args, **kwargs): storage_path="/tmp/ray-storage", ray_address=None, use_ray_cluster=False, + broadcast_join_threshold_mb=25, + max_parallelism_multiplier=1, + target_partition_size_mb=16, + enable_ray_logging=False, + ray_conf={ + "num_cpus": 1, + "object_store_memory": 80 * 1024 * 1024, + "_memory": 400 * 1024 * 1024, + }, ) self.files: list[Any] = [] self.dirs: list[str] = [] @@ -102,6 +111,14 @@ def get_saved_dataset_data_source(self) -> Dict[str, str]: "path": "data/saved_dataset.parquet", } + @staticmethod + def xdist_groups() -> list[str]: + """ + Return xdist group names for Ray tests. + This ensures all Ray tests run on the same pytest worker to avoid OOM issues. + """ + return ["ray"] + # Define the full repo configurations for Ray offline store FULL_REPO_CONFIGS = [ diff --git a/sdk/python/feast/infra/ray_shared_utils.py b/sdk/python/feast/infra/ray_shared_utils.py new file mode 100644 index 00000000000..df8dfeb9fdb --- /dev/null +++ b/sdk/python/feast/infra/ray_shared_utils.py @@ -0,0 +1,363 @@ +from typing import Dict, List, Optional, Union + +import numpy as np +import pandas as pd +from ray.data import Dataset + + +def normalize_timestamp_columns( + data: Union[pd.DataFrame, Dataset], + columns: Union[str, List[str]], + inplace: bool = False, + exclude_columns: Optional[List[str]] = None, +) -> Union[pd.DataFrame, Dataset]: + column_list = [columns] if isinstance(columns, str) else columns + exclude_columns = exclude_columns or [] + + def apply_normalization(series: pd.Series) -> pd.Series: + return ( + pd.to_datetime(series, utc=True, errors="coerce") + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + + if isinstance(data, Dataset): + + def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame: + for column in column_list: + if ( + not batch.empty + and column in batch.columns + and column not in exclude_columns + ): + batch[column] = apply_normalization(batch[column]) + return batch + + return data.map_batches(normalize_batch, batch_format="pandas") + else: + if not inplace: + data = data.copy() + for column in column_list: + if column in data.columns and column not in exclude_columns: + data[column] = apply_normalization(data[column]) + return data + + +def ensure_timestamp_compatibility( + data: Union[pd.DataFrame, Dataset], + timestamp_fields: List[str], + inplace: bool = False, +) -> Union[pd.DataFrame, Dataset]: + from feast.utils import make_df_tzaware + + if isinstance(data, Dataset): + + def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame: + batch = make_df_tzaware(batch) + for field in timestamp_fields: + if field in batch.columns: + batch[field] = ( + pd.to_datetime(batch[field], utc=True, errors="coerce") + .dt.floor("s") + .astype("datetime64[ns, UTC]") + ) + return batch + + return data.map_batches(ensure_compatibility, batch_format="pandas") + else: + if not inplace: + data = data.copy() + from feast.utils import make_df_tzaware + + data = make_df_tzaware(data) + for field in timestamp_fields: + if field in data.columns: + data = normalize_timestamp_columns(data, field, inplace=True) + return data + + +def apply_field_mapping( + data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str] +) -> Union[pd.DataFrame, Dataset]: + def rename_columns(df: pd.DataFrame) -> pd.DataFrame: + return df.rename(columns=field_mapping) + + if isinstance(data, Dataset): + return data.map_batches(rename_columns, batch_format="pandas") + else: + return data.rename(columns=field_mapping) + + +def deduplicate_by_keys_and_timestamp( + data: Union[pd.DataFrame, Dataset], + join_keys: List[str], + timestamp_columns: List[str], +) -> Union[pd.DataFrame, Dataset]: + def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame: + if batch.empty: + return batch + sort_columns = join_keys + timestamp_columns + available_columns = [col for col in sort_columns if col in batch.columns] + if available_columns: + sorted_batch = batch.sort_values( + available_columns, + ascending=[True] * len(join_keys) + [False] * len(timestamp_columns), + ) + deduped_batch = sorted_batch.drop_duplicates( + subset=join_keys, + keep="first", + ) + return deduped_batch + return batch + + if isinstance(data, Dataset): + return data.map_batches(deduplicate_batch, batch_format="pandas") + else: + return deduplicate_batch(data) + + +def broadcast_join( + entity_ds: Dataset, + feature_df: pd.DataFrame, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, +) -> Dataset: + import ray + + def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: + features = ray.get(feature_ref) + if original_join_keys: + feature_join_keys = original_join_keys + entity_join_keys = join_keys + else: + feature_join_keys = join_keys + entity_join_keys = join_keys + feature_cols = [timestamp_field] + feature_join_keys + requested_feats + available_feature_cols = [ + col for col in feature_cols if col in features.columns + ] + features_filtered = features[available_feature_cols].copy() + from .ray_shared_utils import normalize_timestamp_columns + + batch = normalize_timestamp_columns(batch, timestamp_field, inplace=True) + features_filtered = normalize_timestamp_columns( + features_filtered, timestamp_field, inplace=True + ) + if not entity_join_keys: + batch_sorted = batch.sort_values(timestamp_field).reset_index(drop=True) + features_sorted = features_filtered.sort_values( + timestamp_field + ).reset_index(drop=True) + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + ) + else: + for key in entity_join_keys: + if key not in batch.columns: + batch[key] = np.nan + for key in feature_join_keys: + if key not in features_filtered.columns: + features_filtered[key] = np.nan + batch_clean = batch.dropna( + subset=entity_join_keys + [timestamp_field] + ).copy() + features_clean = features_filtered.dropna( + subset=feature_join_keys + [timestamp_field] + ).copy() + if batch_clean.empty or features_clean.empty: + return batch.head(0) + if timestamp_field in batch_clean.columns: + batch_sorted = batch_clean.sort_values( + timestamp_field, ascending=True + ).reset_index(drop=True) + else: + batch_sorted = batch_clean.reset_index(drop=True) + right_sort_columns = [ + k for k in feature_join_keys if k in features_clean.columns + ] + if timestamp_field in features_clean.columns: + right_sort_columns.append(timestamp_field) + if right_sort_columns: + features_clean = features_clean.drop_duplicates( + subset=right_sort_columns, keep="last" + ) + features_sorted = features_clean.sort_values( + right_sort_columns, ascending=True + ).reset_index(drop=True) + else: + features_sorted = features_clean.reset_index(drop=True) + try: + if feature_join_keys: + batch_dedup_cols = [ + k for k in entity_join_keys if k in batch_sorted.columns + ] + if timestamp_field in batch_sorted.columns: + batch_dedup_cols.append(timestamp_field) + if batch_dedup_cols: + batch_sorted = batch_sorted.drop_duplicates( + subset=batch_dedup_cols, keep="last" + ) + feature_dedup_cols = [ + k for k in feature_join_keys if k in features_sorted.columns + ] + if timestamp_field in features_sorted.columns: + feature_dedup_cols.append(timestamp_field) + if feature_dedup_cols: + features_sorted = features_sorted.drop_duplicates( + subset=feature_dedup_cols, keep="last" + ) + if feature_join_keys: + if entity_join_keys == feature_join_keys: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + by=entity_join_keys, + direction="backward", + suffixes=("", "_right"), + ) + else: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + left_by=entity_join_keys, + right_by=feature_join_keys, + direction="backward", + suffixes=("", "_right"), + ) + else: + result = pd.merge_asof( + batch_sorted, + features_sorted, + on=timestamp_field, + direction="backward", + suffixes=("", "_right"), + ) + except Exception: + # fallback to manual join if needed + result = batch_clean # fallback logic can be expanded + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + return result + + feature_ref = ray.put(feature_df) + return entity_ds.map_batches(join_batch_with_features, batch_format="pandas") + + +def distributed_windowed_join( + entity_ds: Dataset, + feature_ds: Dataset, + join_keys: List[str], + timestamp_field: str, + requested_feats: List[str], + window_size: Optional[str] = None, + full_feature_names: bool = False, + feature_view_name: Optional[str] = None, + original_join_keys: Optional[List[str]] = None, +) -> Dataset: + import pandas as pd + + def add_window_and_source(ds, timestamp_field, source_marker, window_size): + def add_window_and_source_batch(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + if timestamp_field in batch.columns: + batch["time_window"] = ( + pd.to_datetime(batch[timestamp_field]) + .dt.floor(window_size) + .astype("datetime64[ns, UTC]") + ) + batch["_data_source"] = source_marker + return batch + + return ds.map_batches(add_window_and_source_batch, batch_format="pandas") + + entity_windowed = add_window_and_source( + entity_ds, timestamp_field, "entity", window_size or "1H" + ) + feature_windowed = add_window_and_source( + feature_ds, timestamp_field, "feature", window_size or "1H" + ) + combined_ds = entity_windowed.union(feature_windowed) + + def windowed_point_in_time_logic(batch: pd.DataFrame) -> pd.DataFrame: + if len(batch) == 0: + return pd.DataFrame() + result_chunks = [] + group_keys = ["time_window"] + join_keys + for group_values, group_data in batch.groupby(group_keys): + entity_data = group_data[group_data["_data_source"] == "entity"].copy() + feature_data = group_data[group_data["_data_source"] == "feature"].copy() + if len(entity_data) > 0 and len(feature_data) > 0: + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + feature_clean = feature_data.drop( + columns=["time_window", "_data_source"] + ) + if join_keys: + merged = pd.merge_asof( + entity_clean.sort_values(join_keys + [timestamp_field]), + feature_clean.sort_values(join_keys + [timestamp_field]), + on=timestamp_field, + by=join_keys, + direction="backward", + ) + else: + merged = pd.merge_asof( + entity_clean.sort_values(timestamp_field), + feature_clean.sort_values(timestamp_field), + on=timestamp_field, + direction="backward", + ) + result_chunks.append(merged) + elif len(entity_data) > 0: + entity_clean = entity_data.drop(columns=["time_window", "_data_source"]) + for feat in requested_feats: + if feat not in entity_clean.columns: + entity_clean[feat] = np.nan + result_chunks.append(entity_clean) + if result_chunks: + result = pd.concat(result_chunks, ignore_index=True) + if full_feature_names and feature_view_name: + for feat in requested_feats: + if feat in result.columns: + new_name = f"{feature_view_name}__{feat}" + result[new_name] = result[feat] + result = result.drop(columns=[feat]) + return result + else: + return pd.DataFrame() + + return combined_ds.map_batches(windowed_point_in_time_logic, batch_format="pandas") + + +def _build_required_columns( + join_key_columns: List[str], + feature_name_columns: List[str], + timestamp_columns: List[str], +) -> List[str]: + """ + Build list of required columns for data processing. + Args: + join_key_columns: List of join key columns + feature_name_columns: List of feature columns + timestamp_columns: List of timestamp columns + Returns: + List of all required columns + """ + all_required_columns = join_key_columns + feature_name_columns + timestamp_columns + if not join_key_columns: + all_required_columns.append("__DUMMY_ENTITY_ID__") + if "event_timestamp" not in all_required_columns: + all_required_columns.append("event_timestamp") + return all_required_columns diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index 23ab80ee1d8..ab7944585b6 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -426,19 +426,19 @@ def list_projects( def refresh(self, project: Optional[str] = None): if self._refresh_lock.locked(): - logger.info("Skipping refresh if already in progress") + logger.debug("Skipping refresh if already in progress") return try: self.cached_registry_proto = self.proto() self.cached_registry_proto_created = _utc_now() except Exception as e: - logger.error(f"Error while refreshing registry: {e}", exc_info=True) + logger.debug(f"Error while refreshing registry: {e}", exc_info=True) def _refresh_cached_registry_if_necessary(self): if self.cache_mode == "sync": # Try acquiring the lock without blocking if not self._refresh_lock.acquire(blocking=False): - logger.info( + logger.debug( "Skipping refresh if lock is already held by another thread" ) return @@ -464,10 +464,10 @@ def _refresh_cached_registry_if_necessary(self): ) ) if expired: - logger.info("Registry cache expired, so refreshing") + logger.debug("Registry cache expired, so refreshing") self.refresh() except Exception as e: - logger.error( + logger.debug( f"Error in _refresh_cached_registry_if_necessary: {e}", exc_info=True, ) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 24c9d30a028..948410c8861 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -49,6 +49,7 @@ "lambda": "feast.infra.compute_engines.aws_lambda.lambda_engine.LambdaComputeEngine", "k8s": "feast.infra.compute_engines.kubernetes.k8s_engine.KubernetesComputeEngine", "spark.engine": "feast.infra.compute_engines.spark.compute.SparkComputeEngine", + "ray.engine": "feast.infra.compute_engines.ray.compute.RayComputeEngine", } LEGACY_ONLINE_STORE_CLASS_FOR_TYPE = { diff --git a/sdk/python/feast/transformation/pandas_transformation.py b/sdk/python/feast/transformation/pandas_transformation.py index 469ddaa7768..6e073c30100 100644 --- a/sdk/python/feast/transformation/pandas_transformation.py +++ b/sdk/python/feast/transformation/pandas_transformation.py @@ -19,29 +19,43 @@ class PandasTransformation(Transformation): def __new__( cls, - udf: Callable[[Any], Any], - udf_string: str, + udf: Optional[Callable[[Any], Any]] = None, + udf_string: Optional[str] = None, name: Optional[str] = None, tags: Optional[dict[str, str]] = None, description: str = "", owner: str = "", ) -> "PandasTransformation": - instance = super(PandasTransformation, cls).__new__( - cls, - mode=TransformationMode.PANDAS, - udf=udf, - name=name, - udf_string=udf_string, - tags=tags, - description=description, - owner=owner, + # Handle Ray deserialization where parameters may not be provided + if udf is None and udf_string is None: + # Create a bare instance for deserialization + instance = object.__new__(cls) + return cast("PandasTransformation", instance) + + # Ensure required parameters are not None before calling parent constructor + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + + return cast( + "PandasTransformation", + super(PandasTransformation, cls).__new__( + cls, + mode=TransformationMode.PANDAS, + udf=udf, + name=name, + udf_string=udf_string, + tags=tags, + description=description, + owner=owner, + ), ) - return cast(PandasTransformation, instance) def __init__( self, - udf: Callable[[Any], Any], - udf_string: str, + udf: Optional[Callable[[Any], Any]] = None, + udf_string: Optional[str] = None, name: Optional[str] = None, tags: Optional[dict[str, str]] = None, description: str = "", @@ -49,6 +63,17 @@ def __init__( *args, **kwargs, ): + # Handle Ray deserialization where parameters may not be provided + if udf is None and udf_string is None: + # Early return for deserialization - don't initialize + return + + # Ensure required parameters are not None before calling parent constructor + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + return_annotation = get_type_hints(udf).get("return", inspect._empty) if return_annotation not in (inspect._empty, pd.DataFrame): raise TypeError( diff --git a/sdk/python/tests/doctest/test_all.py b/sdk/python/tests/doctest/test_all.py index de032264e6d..8a85a72ab45 100644 --- a/sdk/python/tests/doctest/test_all.py +++ b/sdk/python/tests/doctest/test_all.py @@ -71,48 +71,53 @@ def test_docstrings(): next_packages = [] for package in current_packages: - for _, name, is_pkg in pkgutil.walk_packages(package.__path__): - if name in FILES_TO_IGNORE: - continue - - full_name = package.__name__ + "." + name - try: - # https://github.com/feast-dev/feast/issues/5088 - if "ikv" not in full_name and "milvus" not in full_name: - temp_module = importlib.import_module(full_name) - if is_pkg: - next_packages.append(temp_module) - except ModuleNotFoundError: - pass - - # Retrieve the setup and teardown functions defined in this file. - relative_path_from_feast = full_name.split(".", 1)[1] - function_suffix = relative_path_from_feast.replace(".", "_") - setup_function_name = "setup_" + function_suffix - teardown_function_name = "teardown_" + function_suffix - setup_function = globals().get(setup_function_name) - teardown_function = globals().get(teardown_function_name) - - # Execute the test with setup and teardown functions. - try: - if setup_function: - setup_function() - - test_suite = doctest.DocTestSuite( - temp_module, - optionflags=doctest.ELLIPSIS, - ) - if test_suite.countTestCases() > 0: - result = unittest.TextTestRunner(sys.stdout).run(test_suite) - if not result.wasSuccessful(): - successful = False - failed_cases.append(result.failures) - except Exception as e: - successful = False - failed_cases.append((full_name, str(e) + traceback.format_exc())) - finally: - if teardown_function: - teardown_function() + try: + for _, name, is_pkg in pkgutil.walk_packages(package.__path__): + if name in FILES_TO_IGNORE: + continue + + full_name = package.__name__ + "." + name + try: + # https://github.com/feast-dev/feast/issues/5088 + if "ikv" not in full_name and "milvus" not in full_name: + temp_module = importlib.import_module(full_name) + if is_pkg: + next_packages.append(temp_module) + except ModuleNotFoundError: + pass + + # Retrieve the setup and teardown functions defined in this file. + relative_path_from_feast = full_name.split(".", 1)[1] + function_suffix = relative_path_from_feast.replace(".", "_") + setup_function_name = "setup_" + function_suffix + teardown_function_name = "teardown_" + function_suffix + setup_function = globals().get(setup_function_name) + teardown_function = globals().get(teardown_function_name) + + # Execute the test with setup and teardown functions. + try: + if setup_function: + setup_function() + + test_suite = doctest.DocTestSuite( + temp_module, + optionflags=doctest.ELLIPSIS, + ) + if test_suite.countTestCases() > 0: + result = unittest.TextTestRunner(sys.stdout).run(test_suite) + if not result.wasSuccessful(): + successful = False + failed_cases.append(result.failures) + except Exception as e: + successful = False + failed_cases.append( + (full_name, str(e) + traceback.format_exc()) + ) + finally: + if teardown_function: + teardown_function() + except DeprecationWarning: # To catch ray.tune.automl deprecation + pass current_packages = next_packages diff --git a/sdk/python/tests/integration/__init__.py b/sdk/python/tests/integration/__init__.py new file mode 100644 index 00000000000..c66cd71b7e1 --- /dev/null +++ b/sdk/python/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests package.""" diff --git a/sdk/python/tests/integration/compute_engines/__init__.py b/sdk/python/tests/integration/compute_engines/__init__.py new file mode 100644 index 00000000000..6a582448b68 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/__init__.py @@ -0,0 +1 @@ +"""Compute engines integration tests package.""" diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/__init__.py b/sdk/python/tests/integration/compute_engines/ray_compute/__init__.py new file mode 100644 index 00000000000..7938db59420 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/__init__.py @@ -0,0 +1 @@ +"""Ray compute engine integration tests.""" diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py new file mode 100644 index 00000000000..9321ad8d6b7 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py @@ -0,0 +1,72 @@ +"""Test configuration for Ray compute engine integration tests.""" + +from feast.infra.offline_stores.contrib.ray_repo_configuration import ( + RayDataSourceCreator, +) +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.online_store.redis import ( + RedisOnlineStoreCreator, +) + + +def get_ray_compute_engine_test_config() -> IntegrationTestRepoConfig: + """Get test configuration for Ray compute engine.""" + return IntegrationTestRepoConfig( + provider="local", + online_store_creator=RedisOnlineStoreCreator, + offline_store_creator=RayDataSourceCreator, + batch_engine={ + "type": "ray.engine", + "use_ray_cluster": False, + "max_workers": 1, + "enable_optimization": True, + "broadcast_join_threshold_mb": 25, + "target_partition_size_mb": 16, + "window_size_for_joins": "1H", + "ray_conf": { + "num_cpus": 1, + "object_store_memory": 80 * 1024 * 1024, + "_memory": 400 * 1024 * 1024, + }, + }, + ) + + +# Configuration for different test scenarios +COMPUTE_ENGINE_CONFIGS = { + "local": { + "type": "ray.engine", + "use_ray_cluster": False, + "max_workers": 1, + "enable_optimization": True, + "ray_conf": { + "num_cpus": 1, + "object_store_memory": 80 * 1024 * 1024, + "_memory": 400 * 1024 * 1024, + }, + }, + "cluster": { + "type": "ray.engine", + "use_ray_cluster": True, + "ray_address": "ray://localhost:10001", + "max_workers": 2, + "enable_optimization": True, + }, + "optimized": { + "type": "ray.engine", + "use_ray_cluster": False, + "max_workers": 2, + "enable_optimization": True, + "broadcast_join_threshold_mb": 25, + "enable_distributed_joins": True, + "max_parallelism_multiplier": 1, + "target_partition_size_mb": 16, + "ray_conf": { + "num_cpus": 1, + "object_store_memory": 80 * 1024 * 1024, + "_memory": 400 * 1024 * 1024, + }, + }, +} diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py new file mode 100644 index 00000000000..8f42dad7d57 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py @@ -0,0 +1,291 @@ +from datetime import datetime, timedelta +from typing import cast +from unittest.mock import MagicMock + +import pandas as pd +import pytest +import ray +from tqdm import tqdm + +from feast import BatchFeatureView, Entity, Field +from feast.aggregation import Aggregation +from feast.data_source import DataSource +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.ray.compute import RayComputeEngine +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.ray.job import RayDAGRetrievalJob +from feast.infra.offline_stores.contrib.ray_offline_store.ray import ( + RayOfflineStore, +) +from feast.infra.offline_stores.contrib.ray_repo_configuration import ( + RayDataSourceCreator, +) +from feast.types import Float32, Int32, Int64 +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.repo_configuration import ( + construct_test_environment, +) +from tests.integration.feature_repos.universal.online_store.redis import ( + RedisOnlineStoreCreator, +) + +now = datetime.now() +today = datetime.today() + +driver = Entity( + name="driver_id", + description="driver id", +) + + +def create_feature_dataset(ray_environment) -> DataSource: + yesterday = today - timedelta(days=1) + last_week = today - timedelta(days=7) + df = pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.5, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": last_week, + "created": now - timedelta(hours=3), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.7, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.3, + "acc_rate": 0.6, + "avg_daily_trips": 12, + }, + ] + ) + ds = ray_environment.data_source_creator.create_data_source( + df, + ray_environment.feature_store.project, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + return ds + + +def create_entity_df() -> pd.DataFrame: + entity_df = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": today}, + {"driver_id": 1002, "event_timestamp": today}, + ] + ) + return entity_df + + +def create_ray_environment(): + ray_config = IntegrationTestRepoConfig( + provider="local", + online_store_creator=RedisOnlineStoreCreator, + offline_store_creator=RayDataSourceCreator, + batch_engine={ + "type": "ray.engine", + "use_ray_cluster": False, + "max_workers": 2, + "enable_optimization": True, + }, + ) + ray_environment = construct_test_environment( + ray_config, None, entity_key_serialization_version=3 + ) + ray_environment.setup() + return ray_environment + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_get_historical_features(): + """Test Ray compute engine historical feature retrieval.""" + ray_environment = create_ray_environment() + fs = ray_environment.feature_store + registry = fs.registry + data_source = create_feature_dataset(ray_environment) + + def transform_feature(df: pd.DataFrame) -> pd.DataFrame: + df["sum_conv_rate"] = df["sum_conv_rate"] * 2 + df["avg_acc_rate"] = df["avg_acc_rate"] * 2 + return df + + driver_stats_fv = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="pandas", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + udf=transform_feature, + udf_string="transform_feature", + ttl=timedelta(days=3), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=False, + offline=False, + source=data_source, + ) + + entity_df = create_entity_df() + + try: + fs.apply([driver, driver_stats_fv]) + + # Build retrieval task + task = HistoricalRetrievalTask( + project=ray_environment.project, + entity_df=entity_df, + feature_view=driver_stats_fv, + full_feature_name=False, + registry=registry, + ) + + # Run RayComputeEngine + engine = RayComputeEngine( + repo_config=ray_environment.config, + offline_store=RayOfflineStore(), + online_store=MagicMock(), + ) + + ray_dag_retrieval_job = engine.get_historical_features(registry, task) + ray_dataset = cast(RayDAGRetrievalJob, ray_dag_retrieval_job).to_ray_dataset() + df_out = ray_dataset.to_pandas().sort_values("driver_id") + + # Assert output + assert df_out.driver_id.to_list() == [1001, 1002] + assert abs(df_out["sum_conv_rate"].to_list()[0] - 1.6) < 1e-6 + assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[1] - 1.0) < 1e-6 + + finally: + ray_environment.teardown() + if ray.is_initialized(): + ray.shutdown() + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_materialize(): + """Test Ray compute engine materialization.""" + ray_environment = create_ray_environment() + fs = ray_environment.feature_store + registry = fs.registry + data_source = create_feature_dataset(ray_environment) + + def transform_feature(df: pd.DataFrame) -> pd.DataFrame: + df["sum_conv_rate"] = df["sum_conv_rate"] * 2 + df["avg_acc_rate"] = df["avg_acc_rate"] * 2 + return df + + driver_stats_fv = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="pandas", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + udf=transform_feature, + udf_string="transform_feature", + ttl=timedelta(days=3), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=False, + source=data_source, + ) + + def tqdm_builder(length): + return tqdm(length, ncols=100) + + try: + fs.apply([driver, driver_stats_fv]) + + # Build materialization task + task = MaterializationTask( + project=ray_environment.project, + feature_view=driver_stats_fv, + start_time=now - timedelta(days=2), + end_time=now, + tqdm_builder=tqdm_builder, + ) + + # Run RayComputeEngine + engine = RayComputeEngine( + repo_config=ray_environment.config, + offline_store=RayOfflineStore(), + online_store=MagicMock(), + ) + + ray_materialize_jobs = engine.materialize(registry, task) + + assert len(ray_materialize_jobs) == 1 + assert ray_materialize_jobs[0].status() == MaterializationJobStatus.SUCCEEDED + + # Additional assertions can be added here for online store checks + + finally: + ray_environment.teardown() + if ray.is_initialized(): + ray.shutdown() + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_config(): + """Test Ray compute engine configuration.""" + config = RayComputeEngineConfig( + type="ray.engine", + use_ray_cluster=True, + ray_address="ray://localhost:10001", + broadcast_join_threshold_mb=200, + enable_distributed_joins=True, + max_parallelism_multiplier=4, + target_partition_size_mb=128, + window_size_for_joins="2H", + max_workers=4, + enable_optimization=True, + execution_timeout_seconds=3600, + ) + + assert config.type == "ray.engine" + assert config.use_ray_cluster is True + assert config.ray_address == "ray://localhost:10001" + assert config.broadcast_join_threshold_mb == 200 + assert config.window_size_timedelta == timedelta(hours=2) diff --git a/sdk/python/tests/unit/__init__.py b/sdk/python/tests/unit/__init__.py new file mode 100644 index 00000000000..ea3f8b923c2 --- /dev/null +++ b/sdk/python/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests package.""" diff --git a/sdk/python/tests/unit/infra/compute_engines/__init__.py b/sdk/python/tests/unit/infra/compute_engines/__init__.py new file mode 100644 index 00000000000..b1587145566 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/__init__.py @@ -0,0 +1 @@ +"""Compute engines unit tests package.""" diff --git a/sdk/python/tests/unit/infra/compute_engines/ray_compute/__init__.py b/sdk/python/tests/unit/infra/compute_engines/ray_compute/__init__.py new file mode 100644 index 00000000000..2734c36c704 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/ray_compute/__init__.py @@ -0,0 +1 @@ +"""Ray compute engine unit tests.""" diff --git a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py new file mode 100644 index 00000000000..0cd3f7ca4a0 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py @@ -0,0 +1,346 @@ +from datetime import datetime, timedelta + +import pandas as pd +import pytest +import ray + +from feast.aggregation import Aggregation +from feast.infra.compute_engines.dag.context import ColumnInfo +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig +from feast.infra.compute_engines.ray.nodes import ( + RayAggregationNode, + RayDedupNode, + RayFilterNode, + RayJoinNode, + RayReadNode, + RayTransformationNode, + RayWriteNode, +) + + +class DummyInputNode(DAGNode): + def __init__(self, name, output): + super().__init__(name) + self._output = output + + def execute(self, context): + return self._output + + +class DummyFeatureView: + name = "dummy" + online = False + offline = False + + +class DummySource: + pass + + +class DummyRetrievalJob: + def __init__(self, ray_dataset): + self._ray_dataset = ray_dataset + + def to_ray_dataset(self): + return self._ray_dataset + + +@pytest.fixture(scope="session") +def ray_session(): + """Initialize Ray session for testing.""" + if not ray.is_initialized(): + ray.init(num_cpus=2, ignore_reinit_error=True, include_dashboard=False) + yield ray + ray.shutdown() + + +@pytest.fixture +def ray_config(): + """Create Ray compute engine configuration for testing.""" + return RayComputeEngineConfig( + type="ray.engine", + use_ray_cluster=False, + max_workers=2, + enable_optimization=True, + broadcast_join_threshold_mb=50, + target_partition_size_mb=32, + ) + + +@pytest.fixture +def mock_context(): + class DummyOfflineStore: + def offline_write_batch(self, *args, **kwargs): + pass + + class DummyContext: + def __init__(self): + self.registry = None + self.store = None + self.project = "test_project" + self.entity_data = None + self.config = None + self.node_outputs = {} + self.offline_store = DummyOfflineStore() + + return DummyContext() + + +@pytest.fixture +def sample_data(): + """Create sample data for testing.""" + return pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": datetime.now() - timedelta(hours=1), + "created": datetime.now() - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.5, + "avg_daily_trips": 15, + }, + { + "driver_id": 1002, + "event_timestamp": datetime.now() - timedelta(hours=2), + "created": datetime.now() - timedelta(hours=3), + "conv_rate": 0.7, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1001, + "event_timestamp": datetime.now() - timedelta(hours=3), + "created": datetime.now() - timedelta(hours=4), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + ] + ) + + +@pytest.fixture +def column_info(): + """Create a sample ColumnInfo for testing Ray nodes.""" + return ColumnInfo( + join_keys=["driver_id"], + feature_cols=["conv_rate", "acc_rate", "avg_daily_trips"], + ts_col="event_timestamp", + created_ts_col="created", + field_mapping=None, + ) + + +def test_ray_read_node(ray_session, ray_config, mock_context, sample_data, column_info): + """Test RayReadNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + mock_source = DummySource() + node = RayReadNode( + name="read", + source=mock_source, + column_info=column_info, + config=ray_config, + ) + mock_context.registry = None + mock_context.store = None + mock_context.offline_store = None + mock_retrieval_job = DummyRetrievalJob(ray_dataset) + import feast.infra.compute_engines.ray.nodes as ray_nodes + + ray_nodes.create_offline_store_retrieval_job = lambda **kwargs: mock_retrieval_job + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 3 + assert "driver_id" in result_df.columns + assert "conv_rate" in result_df.columns + + +def test_ray_aggregation_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayAggregationNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayAggregationNode( + name="aggregation", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], + group_by_keys=["driver_id"], + timestamp_col="event_timestamp", + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 2 + assert "driver_id" in result_df.columns + assert "sum_conv_rate" in result_df.columns + assert "avg_acc_rate" in result_df.columns + + +def test_ray_join_node(ray_session, ray_config, mock_context, sample_data, column_info): + """Test RayJoinNode functionality.""" + entity_data = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": datetime.now()}, + {"driver_id": 1002, "event_timestamp": datetime.now()}, + ] + ) + feature_dataset = ray.data.from_pandas(sample_data) + feature_value = DAGValue(data=feature_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("feature_node", feature_value) + node = RayJoinNode( + name="join", + column_info=column_info, + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"feature_node": feature_value} + mock_context.entity_df = entity_data + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) >= 2 + assert "driver_id" in result_df.columns + + +def test_ray_transformation_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayTransformationNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + + def transform_feature(df: pd.DataFrame) -> pd.DataFrame: + df["conv_rate_doubled"] = df["conv_rate"] * 2 + return df + + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayTransformationNode( + name="transformation", + transformation=transform_feature, + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 3 + assert "conv_rate_doubled" in result_df.columns + assert ( + result_df["conv_rate_doubled"].iloc[0] == sample_data["conv_rate"].iloc[0] * 2 + ) + + +def test_ray_filter_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayFilterNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayFilterNode( + name="filter", + column_info=column_info, + config=ray_config, + ttl=timedelta(hours=2), + filter_condition=None, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) <= 3 + assert "event_timestamp" in result_df.columns + + +def test_ray_dedup_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayDedupNode functionality.""" + duplicated_data = pd.concat([sample_data, sample_data.iloc[:1]], ignore_index=True) + ray_dataset = ray.data.from_pandas(duplicated_data) + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + node = RayDedupNode( + name="dedup", + column_info=column_info, + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 2 # Should remove the duplicate row + assert "driver_id" in result_df.columns + + +def test_ray_write_node( + ray_session, ray_config, mock_context, sample_data, column_info +): + """Test RayWriteNode functionality.""" + ray_dataset = ray.data.from_pandas(sample_data) + input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) + dummy_node = DummyInputNode("input_node", input_value) + + mock_feature_view = DummyFeatureView() + node = RayWriteNode( + name="write", + feature_view=mock_feature_view, + config=ray_config, + ) + node.add_input(dummy_node) + mock_context.node_outputs = {"input_node": input_value} + result = node.execute(mock_context) + assert isinstance(result, DAGValue) + assert result.format == DAGFormat.RAY + result_df = result.data.to_pandas() + assert len(result_df) == 3 + assert "driver_id" in result_df.columns + + +def test_ray_config_validation(): + """Test Ray configuration validation.""" + # Test valid configuration + config = RayComputeEngineConfig( + type="ray.engine", + use_ray_cluster=False, + max_workers=4, + enable_optimization=True, + broadcast_join_threshold_mb=100, + target_partition_size_mb=64, + window_size_for_joins="30min", + ) + + assert config.type == "ray.engine" + assert config.max_workers == 4 + assert config.window_size_timedelta == timedelta(minutes=30) + + # Test window size parsing + config_hours = RayComputeEngineConfig(window_size_for_joins="2H") + assert config_hours.window_size_timedelta == timedelta(hours=2) + + config_seconds = RayComputeEngineConfig(window_size_for_joins="30s") + assert config_seconds.window_size_timedelta == timedelta(seconds=30) + + # Test invalid window size defaults to 1 hour + config_invalid = RayComputeEngineConfig(window_size_for_joins="invalid") + assert config_invalid.window_size_timedelta == timedelta(hours=1) From 1da9c94011873b712daddb7afe3a8b1c75edad63 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Sat, 26 Jul 2025 17:45:42 +0530 Subject: [PATCH 08/10] fix: Fixed logic for source/derived feature views Signed-off-by: ntkathole --- Makefile | 2 +- sdk/python/feast/feature_view.py | 3 +- sdk/python/feast/feature_view_utils.py | 229 +++++++++++++ .../infra/compute_engines/ray/compute.py | 41 ++- .../feast/infra/compute_engines/ray/config.py | 4 +- .../compute_engines/ray/feature_builder.py | 251 +++++++++----- .../feast/infra/compute_engines/ray/nodes.py | 178 +++++++--- .../contrib/ray_offline_store/ray.py | 131 +++++--- .../compute_engines/ray_compute/conftest.py | 26 ++ .../ray_compute/ray_shared_utils.py | 177 ++++++++++ .../ray_compute/repo_configuration.py | 46 --- .../ray_compute/test_compute.py | 216 +++--------- .../ray_compute/test_source_feature_views.py | 308 ++++++++++++++++++ .../compute_engines/ray_compute/test_nodes.py | 25 -- 14 files changed, 1210 insertions(+), 427 deletions(-) create mode 100644 sdk/python/feast/feature_view_utils.py create mode 100644 sdk/python/tests/integration/compute_engines/ray_compute/conftest.py create mode 100644 sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py create mode 100644 sdk/python/tests/integration/compute_engines/ray_compute/test_source_feature_views.py diff --git a/Makefile b/Makefile index 7bc2570245c..20220164e87 100644 --- a/Makefile +++ b/Makefile @@ -322,7 +322,7 @@ test-python-universal-ray-offline: ## Run Python Ray offline store integration t test-python-ray-compute-engine: ## Run Python Ray compute engine tests PYTHONPATH='.' \ - python -m pytest --integration \ + python -m pytest -v --integration \ sdk/python/tests/integration/compute_engines/ray_compute/ test-python-universal-postgres-online: ## Run Python Postgres integration tests diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index b77737a8bd5..4a086f1b99b 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -510,7 +510,8 @@ def _from_proto_internal( if feature_view_proto.spec.ttl.ToNanoseconds() == 0 else feature_view_proto.spec.ttl.ToTimedelta() ), - source=batch_source if batch_source else source_views, + source=source_views if source_views else batch_source, + sink_source=batch_source if source_views else None, ) if stream_source: feature_view.stream_source = stream_source diff --git a/sdk/python/feast/feature_view_utils.py b/sdk/python/feast/feature_view_utils.py new file mode 100644 index 00000000000..daf28e09dec --- /dev/null +++ b/sdk/python/feast/feature_view_utils.py @@ -0,0 +1,229 @@ +""" +Utility functions for feature view operations including source resolution. +""" + +import logging +import typing +from dataclasses import dataclass +from typing import Callable, Optional + +if typing.TYPE_CHECKING: + from feast.data_source import DataSource + from feast.feature_view import FeatureView + from feast.repo_config import RepoConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class FeatureViewSourceInfo: + """Information about a feature view's data source resolution.""" + + data_source: "DataSource" + source_type: str + has_transformation: bool + transformation_func: Optional[Callable] = None + source_description: str = "" + + +def has_transformation(feature_view: "FeatureView") -> bool: + """Check if a feature view has transformations (UDF or feature_transformation).""" + return ( + getattr(feature_view, "udf", None) is not None + or getattr(feature_view, "feature_transformation", None) is not None + ) + + +def get_transformation_function(feature_view: "FeatureView") -> Optional[Callable]: + """Extract the transformation function from a feature view.""" + feature_transformation = getattr(feature_view, "feature_transformation", None) + if feature_transformation: + # Use feature_transformation if available (preferred) + if hasattr(feature_transformation, "udf") and callable( + feature_transformation.udf + ): + return feature_transformation.udf + + # Fallback to direct UDF + udf = getattr(feature_view, "udf", None) + if udf and callable(udf): + return udf + + return None + + +def find_original_source_view(feature_view: "FeatureView") -> "FeatureView": + """ + Recursively find the original source feature view that has a batch_source. + For derived feature views, this follows the source_views chain until it finds + a feature view with an actual DataSource (batch_source). + """ + current_view = feature_view + while hasattr(current_view, "source_views") and current_view.source_views: + if not current_view.source_views: + break + current_view = current_view.source_views[0] # Assuming single source for now + return current_view + + +def check_sink_source_exists(data_source: "DataSource") -> bool: + """ + Check if a sink_source file actually exists. + Args: + data_source: The DataSource to check + Returns: + bool: True if the source exists, False otherwise + """ + try: + import fsspec + + # Get the source path + if hasattr(data_source, "path"): + source_path = data_source.path + else: + source_path = str(data_source) + + fs, path_in_fs = fsspec.core.url_to_fs(source_path) + return fs.exists(path_in_fs) + except Exception as e: + logger.warning(f"Failed to check if source exists: {e}") + return False + + +def resolve_feature_view_source( + feature_view: "FeatureView", + config: Optional["RepoConfig"] = None, + is_materialization: bool = False, +) -> FeatureViewSourceInfo: + """ + Resolve the appropriate data source for a feature view. + + This handles the complex logic of determining whether to read from: + 1. sink_source (materialized data from parent views) + 2. batch_source (original data source) + 3. Recursive resolution for derived views + + Args: + feature_view: The feature view to resolve + config: Repository configuration (optional) + is_materialization: Whether this is during materialization (affects derived view handling) + + Returns: + FeatureViewSourceInfo: Information about the resolved source + """ + view_has_transformation = has_transformation(feature_view) + transformation_func = ( + get_transformation_function(feature_view) if view_has_transformation else None + ) + + # Check if this is a derived feature view (has source_views) + is_derived_view = ( + hasattr(feature_view, "source_views") and feature_view.source_views + ) + + if not is_derived_view: + # Regular feature view - use its batch_source directly + return FeatureViewSourceInfo( + data_source=feature_view.batch_source, + source_type="batch_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Direct batch_source for {feature_view.name}", + ) + + # This is a derived feature view - need to resolve parent source + if not feature_view.source_views: + raise ValueError( + f"Derived feature view {feature_view.name} has no source_views" + ) + parent_view = feature_view.source_views[0] # Assuming single source for now + + # For derived views: distinguish between materialization and historical retrieval + if ( + hasattr(parent_view, "sink_source") + and parent_view.sink_source + and is_materialization + ): + # During materialization, try to use sink_source if it exists + if check_sink_source_exists(parent_view.sink_source): + logger.debug( + f"Materialization: Using parent {parent_view.name} sink_source" + ) + return FeatureViewSourceInfo( + data_source=parent_view.sink_source, + source_type="sink_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Parent {parent_view.name} sink_source for derived view {feature_view.name}", + ) + else: + logger.info( + f"Parent {parent_view.name} sink_source doesn't exist during materialization" + ) + + # Check if parent is also a derived view first - if so, recursively resolve to original source + if hasattr(parent_view, "source_views") and parent_view.source_views: + # Parent is also a derived view - recursively find original source + original_source_view = find_original_source_view(parent_view) + return FeatureViewSourceInfo( + data_source=original_source_view.batch_source, + source_type="original_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Original source {original_source_view.name} batch_source for derived view {feature_view.name} (via {parent_view.name})", + ) + elif hasattr(parent_view, "batch_source") and parent_view.batch_source: + # Parent has a direct batch_source, use it + return FeatureViewSourceInfo( + data_source=parent_view.batch_source, + source_type="batch_source", + has_transformation=view_has_transformation, + transformation_func=transformation_func, + source_description=f"Parent {parent_view.name} batch_source for derived view {feature_view.name}", + ) + else: + # No valid source found + raise ValueError( + f"Unable to resolve data source for derived feature view {feature_view.name} via parent {parent_view.name}" + ) + + +def resolve_feature_view_source_with_fallback( + feature_view: "FeatureView", + config: Optional["RepoConfig"] = None, + is_materialization: bool = False, +) -> FeatureViewSourceInfo: + """ + Resolve feature view source with fallback error handling. + + This version includes additional error handling and fallback logic + for cases where the primary resolution fails. + """ + try: + return resolve_feature_view_source(feature_view, config, is_materialization) + except Exception as e: + logger.warning(f"Primary source resolution failed for {feature_view.name}: {e}") + + # Fallback: try to find any available source + if hasattr(feature_view, "batch_source") and feature_view.batch_source: + return FeatureViewSourceInfo( + data_source=feature_view.batch_source, + source_type="fallback_batch_source", + has_transformation=has_transformation(feature_view), + transformation_func=get_transformation_function(feature_view), + source_description=f"Fallback batch_source for {feature_view.name}", + ) + elif hasattr(feature_view, "source_views") and feature_view.source_views: + # Try the original source view as last resort + original_view = find_original_source_view(feature_view) + return FeatureViewSourceInfo( + data_source=original_view.batch_source, + source_type="fallback_original_source", + has_transformation=has_transformation(feature_view), + transformation_func=get_transformation_function(feature_view), + source_description=f"Fallback original source {original_view.name} for {feature_view.name}", + ) + else: + raise ValueError( + f"Unable to resolve any data source for feature view {feature_view.name}" + ) diff --git a/sdk/python/feast/infra/compute_engines/ray/compute.py b/sdk/python/feast/infra/compute_engines/ray/compute.py index 3363d483a06..0cd7cddccfd 100644 --- a/sdk/python/feast/infra/compute_engines/ray/compute.py +++ b/sdk/python/feast/infra/compute_engines/ray/compute.py @@ -199,14 +199,43 @@ def _materialize_from_offline_store( end_date=end_date, ) - # Convert to Arrow Table and write to online store + # Convert to Arrow Table and write to online/offline stores arrow_table = retrieval_job.to_arrow() - # TODO: Implement proper online store writing with correct data format conversion - # self.online_store.online_write_batch(...) - logger.debug( - f"Materialization completed, arrow table has {arrow_table.num_rows} rows" - ) + # Write to online store if enabled + if getattr(feature_view, "online", False): + # TODO: Implement proper online store writing with correct data format conversion + logger.debug( + f"Online store writing not implemented yet for {arrow_table.num_rows} rows" + ) + + # Write to offline store if enabled (this handles sink_source automatically for derived views) + if getattr(feature_view, "offline", False): + self.offline_store.offline_write_batch( + config=self.repo_config, + feature_view=feature_view, + table=arrow_table, + progress=lambda x: None, + ) + + # For derived views, also ensure data is written to sink_source if it exists + # This is critical for feature view chaining to work properly + sink_source = getattr(feature_view, "sink_source", None) + if sink_source is not None: + logger.debug( + f"Writing derived view {feature_view.name} to sink_source: {sink_source.path}" + ) + + # Write to sink_source using Ray data + try: + # Convert arrow table to pandas then to ray dataset + df = arrow_table.to_pandas() + ray_dataset = ray.data.from_pandas(df) + ray_dataset.write_parquet(sink_source.path) + except Exception as e: + logger.error( + f"Failed to write to sink_source {sink_source.path}: {e}" + ) return RayMaterializationJob( job_id=job_id, status=MaterializationJobStatus.SUCCEEDED, diff --git a/sdk/python/feast/infra/compute_engines/ray/config.py b/sdk/python/feast/infra/compute_engines/ray/config.py index 5fc66b49659..0e25320651f 100644 --- a/sdk/python/feast/infra/compute_engines/ray/config.py +++ b/sdk/python/feast/infra/compute_engines/ray/config.py @@ -1,7 +1,7 @@ """Configuration for Ray compute engine.""" from datetime import timedelta -from typing import Dict, Literal, Optional +from typing import Any, Dict, Literal, Optional from pydantic import StrictStr @@ -39,7 +39,7 @@ class RayComputeEngineConfig(FeastConfigBaseModel): window_size_for_joins: str = "1H" """Window size for windowed temporal joins""" - ray_conf: Optional[Dict[str, str]] = None + ray_conf: Optional[Dict[str, Any]] = None """Ray configuration parameters""" # Additional configuration options diff --git a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py index 7f49accd0a0..8d6003ff3c8 100644 --- a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py @@ -1,13 +1,18 @@ import logging -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast +from feast.feature_view_utils import resolve_feature_view_source_with_fallback from feast.infra.common.materialization_job import MaterializationTask from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.algorithms.topo import topological_sort +from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.compute_engines.feature_builder import FeatureBuilder +from feast.infra.compute_engines.ray.config import RayComputeEngineConfig from feast.infra.compute_engines.ray.nodes import ( RayAggregationNode, RayDedupNode, + RayDerivedReadNode, RayFilterNode, RayJoinNode, RayReadNode, @@ -38,17 +43,21 @@ def __init__( super().__init__(registry, feature_view, task) self.config = config self.is_historical_retrieval = isinstance(task, HistoricalRetrievalTask) + self.is_materialization = isinstance(task, MaterializationTask) def build_source_node(self, view): """Build the source node for reading feature data.""" - source = view.batch_source + + source_info = resolve_feature_view_source_with_fallback( + view, config=None, is_materialization=self.is_materialization + ) start_time = self.task.start_time end_time = self.task.end_time column_info = self.get_column_info(view) node = RayReadNode( name="source", - source=source, + source=source_info.data_source, column_info=column_info, config=self.config, start_time=start_time, @@ -56,32 +65,49 @@ def build_source_node(self, view): ) self.nodes.append(node) - logger.debug(f"Built source node for {source}") + return node + + def build_aggregation_node(self, view, input_node: DAGNode) -> DAGNode: + """Build aggregation node for Ray.""" + agg_specs = getattr(view, "aggregations", []) + if not agg_specs: + raise ValueError(f"No aggregations found for {view.name}") + + group_by_keys = view.entities + timestamp_col = getattr(view.batch_source, "timestamp_field", "event_timestamp") + + node = RayAggregationNode( + name="aggregation", + aggregations=agg_specs, + group_by_keys=group_by_keys, + timestamp_col=timestamp_col, + config=self.config, + ) + node.add_input(input_node) + + self.nodes.append(node) return node def build_join_node(self, view, input_nodes): - """Build the join node for entity-feature joining.""" + """Build the join node for combining multiple feature sources.""" column_info = self.get_column_info(view) + node = RayJoinNode( name="join", column_info=column_info, config=self.config, - # Pass entity_df information if this is a historical retrieval is_historical_retrieval=self.is_historical_retrieval, ) for input_node in input_nodes: node.add_input(input_node) + self.nodes.append(node) - logger.debug("Built join node") return node def build_filter_node(self, view, input_node): """Build the filter node for TTL and custom filtering.""" - filter_expr = None - if hasattr(view, "filter"): - filter_expr = view.filter - ttl = getattr(view, "ttl", None) + filter_condition = getattr(view, "filter", None) column_info = self.get_column_info(view) node = RayFilterNode( @@ -89,112 +115,191 @@ def build_filter_node(self, view, input_node): column_info=column_info, config=self.config, ttl=ttl, - filter_condition=filter_expr, + filter_condition=filter_condition, ) - node.add_input(input_node) - self.nodes.append(node) - logger.debug(f"Built filter node with TTL: {ttl}") - return node - - def build_aggregation_node(self, view, input_node): - """Build the aggregation node for feature aggregations.""" - if not hasattr(view, "aggregations"): - raise ValueError("Feature view does not have aggregations") - - aggregations = view.aggregations - group_by_keys = view.entities - - # Get timestamp field from batch source - timestamp_field = getattr( - view.batch_source, "timestamp_field", "event_timestamp" - ) - - node = RayAggregationNode( - name="aggregation", - aggregations=aggregations, - group_by_keys=group_by_keys, - timestamp_col=timestamp_field, - config=self.config, - ) - node.add_input(input_node) self.nodes.append(node) - logger.debug(f"Built aggregation node with {len(aggregations)} aggregations") return node def build_dedup_node(self, view, input_node): - """Build the deduplication node for removing duplicates.""" + """Build the deduplication node for removing duplicate records.""" column_info = self.get_column_info(view) + node = RayDedupNode( name="dedup", column_info=column_info, config=self.config, ) - node.add_input(input_node) + self.nodes.append(node) - logger.debug("Built dedup node") return node def build_transformation_node(self, view, input_nodes): - """Build the transformation node for feature transformations.""" - transformation = None - - # Check for feature_transformation first - if hasattr(view, "feature_transformation") and view.feature_transformation: - transformation = view.feature_transformation - # For BatchFeatureView, also check for direct UDF - elif hasattr(view, "udf") and view.udf: - transformation = view.udf - else: - raise ValueError("Feature view does not have feature transformation or UDF") + """Build the transformation node for user-defined transformations.""" + feature_transformation = getattr(view, "feature_transformation", None) + udf = getattr(view, "udf", None) + + transformation = feature_transformation or udf + if not transformation: + raise ValueError(f"No feature transformation found for {view.name}") node = RayTransformationNode( name="transformation", transformation=transformation, config=self.config, ) - for input_node in input_nodes: node.add_input(input_node) + self.nodes.append(node) - transformation_name = getattr( - transformation, "name", getattr(transformation, "__name__", "unknown") - ) - logger.debug(f"Built transformation node: {transformation_name}") return node def build_output_nodes(self, view, final_node): - """Build the output node for writing results.""" + """Build the output node for writing processed features.""" node = RayWriteNode( name="output", feature_view=view, - config=self.config, + inputs=[final_node], ) - node.add_input(final_node) self.nodes.append(node) - logger.debug("Built output node") return node def build_validation_node(self, view, input_node): - """Build the validation node for data quality checks.""" - # For now, validation is handled in the retrieval job - # This could be extended to include Ray-specific validation logic - logger.debug("Validation node not implemented yet") + """Build the validation node for feature validation.""" + # TODO: Implement validation logic return input_node - def build(self) -> ExecutionPlan: - """Build execution plan with optimized order for aggregation scenarios.""" + def _build(self, view, input_nodes: Optional[List[DAGNode]]) -> DAGNode: + """Override _build to handle derived views during materialization.""" + # Step 1: build source node + if view.data_source: + last_node = self.build_source_node(view) + + # If source node is None (derived view during materialization), use input nodes + if last_node is None and input_nodes: + if self._should_transform(view): + last_node = self.build_transformation_node(view, input_nodes) + else: + last_node = self.build_join_node(view, input_nodes) + elif last_node is not None: + if self._should_transform(view): + # Transform applied to the source data + last_node = self.build_transformation_node(view, [last_node]) + + # If there are input nodes, transform or join them + elif input_nodes: + # User-defined transform handles the merging of input views + if self._should_transform(view): + last_node = self.build_transformation_node(view, input_nodes) + # Default join + else: + last_node = self.build_join_node(view, input_nodes) + else: + raise ValueError(f"FeatureView {view.name} has no valid source or inputs") + + # Skip subsequent steps if last_node is None + if last_node is None: + raise ValueError(f"Failed to build processing node for {view.name}") + + # Step 2: filter + last_node = self.build_filter_node(view, last_node) + + # Step 3: aggregate or dedupe + if self._should_aggregate(view): + last_node = self.build_aggregation_node(view, last_node) + elif self._should_dedupe(view): + last_node = self.build_dedup_node(view, last_node) + + # Step 4: validate + if self._should_validate(view): + last_node = self.build_validation_node(view, last_node) + + return last_node - # For historical retrieval with aggregations, use a different execution order + def build(self) -> ExecutionPlan: + """Build execution plan with support for derived feature views and sink_source writing.""" if self.is_historical_retrieval and self._should_aggregate(self.feature_view): return self._build_aggregation_optimized_plan() - # Use the default build logic for other scenarios + if self.is_materialization: + return self._build_materialization_plan() + return super().build() + def _build_materialization_plan(self) -> ExecutionPlan: + """Build execution plan for materialization with intermediate sink writes.""" + logger.info(f"Building materialization plan for {self.feature_view.name}") + + # Step 1: Topo sort the FeatureViewNode DAG (Logical DAG) + logical_nodes = self.feature_resolver.topological_sort(self.dag_root) + logger.info( + f"Logical nodes in topo order: {[node.view.name for node in logical_nodes]}" + ) + + # Step 2: For each FeatureView, build its corresponding execution DAGNode and write node + # Build them in dependency order to ensure proper execution + view_to_write_node: Dict[str, RayWriteNode] = {} + + for i, logical_node in enumerate(logical_nodes): + view = logical_node.view + logger.info( + f"Building nodes for view {view.name} (step {i + 1}/{len(logical_nodes)})" + ) + + # For derived views, we need to ensure parent views are materialized first + # So we create a processing chain that depends on parent write nodes + parent_write_nodes = [] + processing_node: DAGNode + if hasattr(view, "source_views") and view.source_views: + # This is a derived view - collect parent write nodes as dependencies + for parent in logical_node.inputs: + if parent.view.name in view_to_write_node: + parent_write_nodes.append(view_to_write_node[parent.view.name]) + + if parent_write_nodes: + # For derived views, create a simple passthrough node that depends on parents + # This ensures the derived view processing only starts after parents are materialized + processing_node = RayDerivedReadNode( + name=f"{view.name}:derived_read", + feature_view=view, + parent_dependencies=cast(List[DAGNode], parent_write_nodes), + config=self.config, + column_info=self.get_column_info(view), + is_materialization=self.is_materialization, + ) + self.nodes.append(processing_node) + else: + # Parent not yet built - this shouldn't happen in topo order + raise ValueError(f"Parent views for {view.name} not yet built") + else: + # Regular view - build normal processing chain + processing_node = self._build(view, None) + + # Create a write node for this view + write_node = RayWriteNode( + name=f"{view.name}:write", + feature_view=view, + inputs=[processing_node], + ) + + view_to_write_node[view.name] = write_node + logger.info(f"Created write node for {view.name}") + + # Step 3: The final write node is the one for the top-level feature view + final_node = view_to_write_node[self.feature_view.name] + + # Step 4: Topo sort the final DAG from the output node (Physical DAG) + sorted_nodes = topological_sort(final_node) + + # Step 5: Update self.nodes to include all nodes for the execution plan + self.nodes = sorted_nodes + + # Step 6: Return sorted execution plan + return ExecutionPlan(sorted_nodes) + def _build_aggregation_optimized_plan(self) -> ExecutionPlan: """Build execution plan optimized for aggregation scenarios.""" @@ -214,11 +319,7 @@ def _build_aggregation_optimized_plan(self) -> ExecutionPlan: if self._should_transform(self.feature_view): last_node = self.build_transformation_node(self.feature_view, [last_node]) - # 6. Validation if needed - if self._should_validate(self.feature_view): - last_node = self.build_validation_node(self.feature_view, last_node) - - # 7. Output + # 6. Output last_node = self.build_output_nodes(self.feature_view, last_node) return ExecutionPlan(self.nodes) diff --git a/sdk/python/feast/infra/compute_engines/ray/nodes.py b/sdk/python/feast/infra/compute_engines/ray/nodes.py index 17c82fcc6d6..2d919da7274 100644 --- a/sdk/python/feast/infra/compute_engines/ray/nodes.py +++ b/sdk/python/feast/infra/compute_engines/ray/nodes.py @@ -1,14 +1,18 @@ import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import List, Optional, Union +import dill import pandas as pd +import pyarrow as pa import ray from ray.data import Dataset -from feast import BatchFeatureView, StreamFeatureView +from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.aggregation import Aggregation from feast.data_source import DataSource +from feast.feature_view_utils import resolve_feature_view_source_with_fallback +from feast.infra.common.serde import SerializedArtifacts from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode @@ -51,7 +55,6 @@ def __init__( def execute(self, context: ExecutionContext) -> DAGValue: """Execute the read operation to load data from the offline store.""" try: - # Use utility function to create retrieval job retrieval_job = create_offline_store_retrieval_job( data_source=self.source, column_info=self.column_info, @@ -60,21 +63,16 @@ def execute(self, context: ExecutionContext) -> DAGValue: end_time=self.end_time, ) - # Convert to Ray Dataset if hasattr(retrieval_job, "to_ray_dataset"): - # If the retrieval job supports Ray datasets directly ray_dataset = retrieval_job.to_ray_dataset() else: - # Fall back to converting from Arrow/Pandas try: arrow_table = retrieval_job.to_arrow() ray_dataset = ray.data.from_arrow(arrow_table) except Exception: - # Ultimate fallback to pandas df = retrieval_job.to_df() ray_dataset = ray.data.from_pandas(df) - # Apply field mapping if needed field_mapping = getattr(self.source, "field_mapping", None) if field_mapping: ray_dataset = apply_field_mapping(ray_dataset, field_mapping) @@ -126,14 +124,12 @@ def execute(self, context: ExecutionContext) -> DAGValue: metadata={"joined": False}, ) - # Convert entity_df to Ray Dataset entity_df = context.entity_df if isinstance(entity_df, pd.DataFrame): entity_dataset = ray.data.from_pandas(entity_df) else: entity_dataset = entity_df - # Perform the join using Ray operations join_keys = self.column_info.join_keys timestamp_col = self.column_info.timestamp_column requested_feats = getattr(self.column_info, "feature_cols", []) @@ -146,7 +142,6 @@ def execute(self, context: ExecutionContext) -> DAGValue: if hasattr(sample_data[0], "columns"): feature_cols = sample_data[0].columns.tolist() else: - # Handle other data formats feature_cols = ( list(sample_data[0].keys()) if isinstance(sample_data[0], dict) @@ -256,11 +251,6 @@ def apply_filters(batch: pd.DataFrame) -> pd.DataFrame: if self.ttl: timestamp_col = self.column_info.timestamp_column if timestamp_col in filtered_batch.columns: - # Import necessary modules at the top of the function - from datetime import timezone - - import pandas as pd - # Convert to datetime if not already if not pd.api.types.is_datetime64_any_dtype( filtered_batch[timestamp_col] @@ -528,17 +518,8 @@ def __init__( config: RayComputeEngineConfig, ): super().__init__(name) - # Extract the UDF function to avoid serialization issues with PandasTransformation - if hasattr(transformation, "udf") and callable(transformation.udf): - self.transformation_udf = transformation.udf - self.transformation_name = getattr(transformation, "name", "unknown") - elif callable(transformation): - # Handle direct UDF functions - self.transformation_udf = transformation - self.transformation_name = getattr(transformation, "__name__", "unknown") - else: - self.transformation_udf = None - self.transformation_name = "unknown" + self.transformation = transformation + self.transformation_name = getattr(transformation, "name", "unknown") self.config = config def execute(self, context: ExecutionContext) -> DAGValue: @@ -547,21 +528,26 @@ def execute(self, context: ExecutionContext) -> DAGValue: input_value.assert_format(DAGFormat.RAY) dataset: Dataset = input_value.data - # Use the extracted UDF function directly - transformation_func = self.transformation_udf + transformation_serialized = None + if hasattr(self.transformation, "udf") and callable(self.transformation.udf): + transformation_serialized = dill.dumps(self.transformation.udf) + elif callable(self.transformation): + transformation_serialized = dill.dumps(self.transformation) - def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame: - """Apply the transformation to the batch.""" + def apply_transformation_with_serialized_udf( + batch: pd.DataFrame, + ) -> pd.DataFrame: + """Apply the transformation using pre-serialized UDF.""" if batch.empty: return batch try: - # Apply the transformation function directly - if transformation_func and callable(transformation_func): + if transformation_serialized: + transformation_func = dill.loads(transformation_serialized) transformed_batch = transformation_func(batch) else: logger.warning( - "Transformation function not available, returning original batch" + "No serialized transformation available, returning original batch" ) transformed_batch = batch @@ -571,7 +557,7 @@ def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame: return batch transformed_dataset = dataset.map_batches( - apply_transformation, batch_format="pandas" + apply_transformation_with_serialized_udf, batch_format="pandas" ) return DAGValue( @@ -584,20 +570,96 @@ def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame: ) -class RayWriteNode(DAGNode): +class RayDerivedReadNode(DAGNode): """ - Ray node for writing results to online/offline stores. + Ray node for reading derived feature views after parent dependencies are materialized. + This node ensures that parent feature views are fully materialized before reading from their sink_source. """ def __init__( self, name: str, - feature_view: Union[BatchFeatureView, StreamFeatureView], + feature_view: FeatureView, + parent_dependencies: List[DAGNode], config: RayComputeEngineConfig, + column_info, + is_materialization: bool = True, ): super().__init__(name) self.feature_view = feature_view self.config = config + self.column_info = column_info + self.is_materialization = is_materialization + + # Add parent dependencies to ensure they execute first + for parent in parent_dependencies: + self.add_input(parent) + + def execute(self, context: ExecutionContext) -> DAGValue: + """Execute the derived read operation after parents are materialized.""" + # Wait for all parent dependencies to complete + # The inputs contain the parent write nodes which have completed materialization + self.get_input_values(context) + source_info = resolve_feature_view_source_with_fallback( + self.feature_view, config=None, is_materialization=self.is_materialization + ) + data_source = source_info.data_source + try: + retrieval_job = create_offline_store_retrieval_job( + data_source=data_source, + column_info=self.column_info, + context=context, + start_time=None, + end_time=None, + ) + + # Convert to Ray Dataset + if hasattr(retrieval_job, "to_ray_dataset"): + ray_dataset = retrieval_job.to_ray_dataset() + else: + try: + arrow_table = retrieval_job.to_arrow() + ray_dataset = ray.data.from_arrow(arrow_table) + except Exception: + df = retrieval_job.to_df() + ray_dataset = ray.data.from_pandas(df) + + # Apply field mapping if needed + field_mapping = getattr(data_source, "field_mapping", None) + if field_mapping: + ray_dataset = apply_field_mapping(ray_dataset, field_mapping) + + return DAGValue( + data=ray_dataset, + format=DAGFormat.RAY, + metadata={ + "source": "derived_from_parent", + "source_description": source_info.source_description, + "data_source_path": getattr(data_source, "path", "unknown"), + }, + ) + + except Exception as e: + logger.error( + f"Failed to read derived view {self.feature_view.name} from parent data source {getattr(data_source, 'path', 'unknown')}: {e}" + ) + raise + + +class RayWriteNode(DAGNode): + """ + Ray node for writing results to online/offline stores and sink_source paths. + This node handles writing intermediate results for derived feature views. + """ + + def __init__( + self, + name: str, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + inputs=None, + ): + super().__init__(name, inputs=inputs) + self.feature_view = feature_view def execute(self, context: ExecutionContext) -> DAGValue: """Execute the write operation.""" @@ -605,47 +667,54 @@ def execute(self, context: ExecutionContext) -> DAGValue: input_value.assert_format(DAGFormat.RAY) dataset: Dataset = input_value.data - def write_batch(batch: pd.DataFrame) -> pd.DataFrame: - """Write each batch to the appropriate stores.""" + serialized_artifacts = SerializedArtifacts.serialize( + feature_view=self.feature_view, repo_config=context.repo_config + ) + + def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame: + """Write each batch using pre-serialized artifacts.""" if batch.empty: return batch try: - # Convert to Arrow Table for writing - import pyarrow as pa + ( + feature_view, + online_store, + offline_store, + repo_config, + ) = serialized_artifacts.unserialize() arrow_table = pa.Table.from_pandas(batch) # Write to online store if enabled - if getattr(self.feature_view, "online", False): + if getattr(feature_view, "online", False): # TODO: Implement proper online store writing with correct data format conversion logger.debug( f"Online store writing not implemented yet for {len(batch)} rows" ) # Write to offline store if enabled - if getattr(self.feature_view, "offline", False): + if getattr(feature_view, "offline", False): try: - context.offline_store.offline_write_batch( - config=context.repo_config, - feature_view=self.feature_view, + offline_store.offline_write_batch( + config=repo_config, + feature_view=feature_view, table=arrow_table, progress=lambda x: None, ) - logger.debug(f"Wrote {len(batch)} rows to offline store") except Exception as e: logger.error(f"Failed to write to offline store: {e}") + raise return batch except Exception as e: logger.error(f"Write operation failed: {e}") - return batch - - # Apply write operation to all batches - written_dataset = dataset.map_batches(write_batch, batch_format="pandas") + raise - # Materialize the dataset to ensure writes are executed + written_dataset = dataset.map_batches( + write_batch_with_serialized_artifacts, batch_format="pandas" + ) written_dataset = written_dataset.materialize() return DAGValue( @@ -656,5 +725,8 @@ def write_batch(batch: pd.DataFrame) -> pd.DataFrame: "feature_view": self.feature_view.name, "online": getattr(self.feature_view, "online", False), "offline": getattr(self.feature_view, "offline", False), + "batch_source_path": getattr( + getattr(self.feature_view, "batch_source", None), "path", "unknown" + ), }, ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 1e0ef944469..f16766b7d17 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import dill import fsspec import numpy as np import pandas as pd @@ -21,6 +22,7 @@ ) from feast.feature_logging import LoggingConfig, LoggingSource from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView +from feast.feature_view_utils import resolve_feature_view_source_with_fallback from feast.infra.offline_stores.file_source import ( FileLoggingDestination, FileSource, @@ -47,7 +49,7 @@ from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage, ValidationReference from feast.type_map import feast_value_type_to_pandas_type, pa_to_feast_value_type -from feast.utils import _get_column_names, make_df_tzaware +from feast.utils import _get_column_names, make_df_tzaware, make_tzaware from feast.value_type import ValueType logger = logging.getLogger(__name__) @@ -335,12 +337,12 @@ def _convert_array_column(series: pd.Series, value_type: ValueType) -> pd.Series target_dtype = base_type_map.get(value_type, object) - def convert_array_item(item): + def convert_array_item(item) -> Union[np.ndarray, Any]: if item is None or (isinstance(item, list) and len(item) == 0): if target_dtype == object: - return np.array([], dtype=object) + return np.empty(0, dtype=object) else: - return np.array([], dtype=target_dtype) + return np.empty(0, dtype=target_dtype) else: return item @@ -378,7 +380,7 @@ class RayResourceManager: Manages Ray cluster resources for optimal performance. """ - def __init__(self, config: Optional[RayOfflineStoreConfig] = None): + def __init__(self, config: Optional[RayOfflineStoreConfig] = None) -> None: """ Initialize the resource manager with cluster resource information. """ @@ -473,7 +475,7 @@ class RayDataProcessor: Optimized data processing with Ray for feature store operations. """ - def __init__(self, resource_manager: RayResourceManager): + def __init__(self, resource_manager: RayResourceManager) -> None: """ Initialize the data processor with a resource manager. """ @@ -977,7 +979,9 @@ def _create_metadata(self) -> RetrievalMetadata: max_event_timestamp=max_timestamp, ) - def _set_metadata_info(self, feature_refs: List[str], entity_df: pd.DataFrame): + def _set_metadata_info( + self, feature_refs: List[str], entity_df: pd.DataFrame + ) -> None: """Set the feature references and entity DataFrame for metadata creation.""" self._feature_refs = feature_refs self._entity_df = entity_df @@ -1181,16 +1185,15 @@ def schema(self) -> pa.Schema: class RayOfflineStore(OfflineStore): - def __init__(self): + def __init__(self) -> None: self._staging_location: Optional[str] = None self._ray_initialized: bool = False self._resource_manager: Optional[RayResourceManager] = None self._data_processor: Optional[RayDataProcessor] = None @staticmethod - def _suppress_ray_logging(): + def _suppress_ray_logging() -> None: """Suppress Ray and Ray Data logging completely.""" - import logging import warnings # Suppress Ray warnings @@ -1233,7 +1236,7 @@ def _suppress_ray_logging(): pass # Ignore if Ray Data is not available @staticmethod - def _ensure_ray_initialized(config: Optional[RepoConfig] = None): + def _ensure_ray_initialized(config: Optional[RepoConfig] = None) -> None: """Ensure Ray is initialized with proper configuration.""" ray_config = None if config and hasattr(config, "offline_store"): @@ -1304,7 +1307,7 @@ def _ensure_ray_initialized(config: Optional[RepoConfig] = None): f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" ) - def _init_ray(self, config: RepoConfig): + def _init_ray(self, config: RepoConfig) -> None: ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) RayOfflineStore._ensure_ray_initialized(config) @@ -1852,50 +1855,35 @@ def _create_filtered_dataset( except Exception as e: raise ValueError(f"Failed to get dataset schema: {e}") - if start_date or end_date: - try: - if start_date and end_date: + def normalize(dt): + return make_tzaware(dt) if dt and dt.tzinfo is None else dt - def filter_func(row): - try: - ts = row[timestamp_field] - return start_date <= ts <= end_date - except KeyError: - raise KeyError( - f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" - ) + start_date = normalize(start_date) + end_date = normalize(end_date) - filtered_ds = ds.filter(filter_func) - elif start_date: + try: + if start_date and end_date: - def filter_func(row): - try: - ts = row[timestamp_field] - return ts >= start_date - except KeyError: - raise KeyError( - f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" - ) + def filter_by_timestamp_range(batch): + return (batch[timestamp_field] >= start_date) & ( + batch[timestamp_field] <= end_date + ) - filtered_ds = ds.filter(filter_func) - elif end_date: + ds = ds.filter(filter_by_timestamp_range) + elif start_date: - def filter_func(row): - try: - ts = row[timestamp_field] - return ts <= end_date - except KeyError: - raise KeyError( - f"Timestamp field '{timestamp_field}' not found in row. Available keys: {list(row.keys())}" - ) + def filter_by_start_date(batch): + return batch[timestamp_field] >= start_date - filtered_ds = ds.filter(filter_func) - else: - return ds + ds = ds.filter(filter_by_start_date) + elif end_date: - return filtered_ds - except Exception as e: - raise RuntimeError(f"Failed to filter by timestamp: {e}") + def filter_by_end_date(batch): + return batch[timestamp_field] <= end_date + + ds = ds.filter(filter_by_end_date) + except Exception as e: + raise RuntimeError(f"Failed to filter dataset by timestamp: {e}") return ds @@ -1989,9 +1977,50 @@ def get_historical_features( f"(available: {available_feature_names})" ) - source_path = store._get_source_path(fv.batch_source, config) + source_info = resolve_feature_view_source_with_fallback( + fv, config, is_materialization=False + ) + + # Read from the resolved data source + source_path = store._get_source_path(source_info.data_source, config) feature_ds = ray.data.read_parquet(source_path) - feature_size = feature_ds.size_bytes() + logger.info( + f"Reading feature view {fv.name}: {source_info.source_description}" + ) + + # Apply transformation if available + if source_info.has_transformation and source_info.transformation_func: + transformation_serialized = dill.dumps(source_info.transformation_func) + + def apply_transformation_with_serialized_func( + batch: pd.DataFrame, + ) -> pd.DataFrame: + if batch.empty: + return batch + try: + logger.debug( + f"Applying transformation to batch with columns: {list(batch.columns)}" + ) + transformation_func = dill.loads(transformation_serialized) + result = transformation_func(batch) + logger.debug( + f"Transformation result has columns: {list(result.columns)}" + ) + return result + except Exception as e: + logger.error(f"Transformation failed for {fv.name}: {e}") + return batch + + feature_ds = feature_ds.map_batches( + apply_transformation_with_serialized_func, batch_format="pandas" + ) + logger.info(f"Applied transformation to feature view {fv.name}") + elif source_info.has_transformation: + logger.warning( + f"Feature view {fv.name} marked as having transformation but no UDF found" + ) + + feature_size = feature_ds.size_bytes() or 0 field_mapping = getattr(fv.batch_source, "field_mapping", None) if field_mapping: diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/conftest.py b/sdk/python/tests/integration/compute_engines/ray_compute/conftest.py new file mode 100644 index 00000000000..885b1555ec7 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/conftest.py @@ -0,0 +1,26 @@ +"""Pytest configuration and fixtures for Ray compute engine tests. + +This module exposes fixtures from ray_shared_utils.py so they can be +auto-discovered by pytest. +""" + +from tests.integration.compute_engines.ray_compute.ray_shared_utils import ( + entity_df, + feature_dataset, + ray_environment, + temp_dir, +) + + +def pytest_configure(config): + """Configure pytest for Ray tests.""" + config.addinivalue_line("markers", "ray: mark test as requiring Ray compute engine") + + +__all__ = [ + "entity_df", + "feature_dataset", + "ray_environment", + "temp_dir", + "pytest_configure", +] diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py b/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py new file mode 100644 index 00000000000..9e9aabc4f90 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py @@ -0,0 +1,177 @@ +"""Shared fixtures and utilities for Ray compute engine tests.""" + +import os +import tempfile +import time +import uuid +from datetime import timedelta +from typing import Generator + +import pandas as pd +import pytest +import ray + +from feast import Entity, FileSource +from feast.data_source import DataSource +from feast.utils import _utc_now +from tests.integration.feature_repos.repo_configuration import ( + construct_test_environment, +) + +from .repo_configuration import get_ray_compute_engine_test_config + +now = _utc_now().replace(microsecond=0, second=0, minute=0) +today = now.replace(hour=0, minute=0, second=0, microsecond=0) + + +def get_test_date_range(days_back: int = 7) -> tuple: + """Get a standard test date range (start_date, end_date) for testing.""" + end_date = now + start_date = now - timedelta(days=days_back) + return start_date, end_date + + +driver = Entity( + name="driver_id", + description="driver id", +) + + +def create_feature_dataset(ray_environment) -> DataSource: + """Create a test dataset for feature views.""" + yesterday = today - timedelta(days=1) + last_week = today - timedelta(days=7) + df = pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.5, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": last_week, + "created": now - timedelta(hours=3), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday, + "created": now - timedelta(hours=2), + "conv_rate": 0.7, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.3, + "acc_rate": 0.6, + "avg_daily_trips": 12, + }, + ] + ) + ds = ray_environment.data_source_creator.create_data_source( + df, + ray_environment.feature_store.project, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + return ds + + +def create_entity_df() -> pd.DataFrame: + """Create entity dataframe for testing.""" + entity_df = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": today}, + {"driver_id": 1002, "event_timestamp": today}, + ] + ) + return entity_df + + +def create_unique_sink_source(temp_dir: str, base_name: str) -> FileSource: + """Create a unique sink source to avoid path collisions during parallel test execution.""" + timestamp = int(time.time() * 1000) + process_id = os.getpid() + unique_id = str(uuid.uuid4())[:8] + + # Create a unique directory for this sink - Ray needs directory paths for materialization + sink_dir = os.path.join( + temp_dir, f"{base_name}_{timestamp}_{process_id}_{unique_id}" + ) + os.makedirs(sink_dir, exist_ok=True) + + return FileSource( + name=f"{base_name}_sink_source", + path=sink_dir, # Use directory path - Ray will create files inside + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + + +def cleanup_ray_environment(ray_environment): + """Safely cleanup Ray environment and resources.""" + try: + ray_environment.teardown() + except Exception as e: + print(f"Warning: Ray environment teardown failed: {e}") + + # Ensure Ray is shut down completely + try: + if ray.is_initialized(): + ray.shutdown() + time.sleep(0.2) # Brief pause to ensure clean shutdown + except Exception as e: + print(f"Warning: Ray shutdown failed: {e}") + + +def create_ray_environment(): + """Create Ray test environment using the standardized config.""" + ray_config = get_ray_compute_engine_test_config() + ray_environment = construct_test_environment( + ray_config, None, entity_key_serialization_version=3 + ) + ray_environment.setup() + return ray_environment + + +@pytest.fixture(scope="function") +def ray_environment() -> Generator: + """Pytest fixture to provide a Ray environment for tests with automatic cleanup.""" + try: + if ray.is_initialized(): + ray.shutdown() + time.sleep(0.2) + except Exception: + pass + + environment = create_ray_environment() + yield environment + cleanup_ray_environment(environment) + + +@pytest.fixture +def feature_dataset(ray_environment) -> DataSource: + """Fixture that provides a feature dataset for testing.""" + return create_feature_dataset(ray_environment) + + +@pytest.fixture +def entity_df() -> pd.DataFrame: + """Fixture that provides an entity dataframe for testing.""" + return create_entity_df() + + +@pytest.fixture +def temp_dir() -> Generator[str, None, None]: + """Fixture that provides a temporary directory for test artifacts.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py index 9321ad8d6b7..6b74859022f 100644 --- a/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py +++ b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py @@ -22,51 +22,5 @@ def get_ray_compute_engine_test_config() -> IntegrationTestRepoConfig: "use_ray_cluster": False, "max_workers": 1, "enable_optimization": True, - "broadcast_join_threshold_mb": 25, - "target_partition_size_mb": 16, - "window_size_for_joins": "1H", - "ray_conf": { - "num_cpus": 1, - "object_store_memory": 80 * 1024 * 1024, - "_memory": 400 * 1024 * 1024, - }, }, ) - - -# Configuration for different test scenarios -COMPUTE_ENGINE_CONFIGS = { - "local": { - "type": "ray.engine", - "use_ray_cluster": False, - "max_workers": 1, - "enable_optimization": True, - "ray_conf": { - "num_cpus": 1, - "object_store_memory": 80 * 1024 * 1024, - "_memory": 400 * 1024 * 1024, - }, - }, - "cluster": { - "type": "ray.engine", - "use_ray_cluster": True, - "ray_address": "ray://localhost:10001", - "max_workers": 2, - "enable_optimization": True, - }, - "optimized": { - "type": "ray.engine", - "use_ray_cluster": False, - "max_workers": 2, - "enable_optimization": True, - "broadcast_join_threshold_mb": 25, - "enable_distributed_joins": True, - "max_parallelism_multiplier": 1, - "target_partition_size_mb": 16, - "ray_conf": { - "num_cpus": 1, - "object_store_memory": 80 * 1024 * 1024, - "_memory": 400 * 1024 * 1024, - }, - }, -} diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py index 8f42dad7d57..73cc6d19cd1 100644 --- a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py @@ -1,15 +1,13 @@ -from datetime import datetime, timedelta +from datetime import timedelta from typing import cast from unittest.mock import MagicMock import pandas as pd import pytest -import ray from tqdm import tqdm -from feast import BatchFeatureView, Entity, Field +from feast import BatchFeatureView, Field from feast.aggregation import Aggregation -from feast.data_source import DataSource from feast.infra.common.materialization_job import ( MaterializationJobStatus, MaterializationTask, @@ -21,114 +19,21 @@ from feast.infra.offline_stores.contrib.ray_offline_store.ray import ( RayOfflineStore, ) -from feast.infra.offline_stores.contrib.ray_repo_configuration import ( - RayDataSourceCreator, -) from feast.types import Float32, Int32, Int64 -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, -) -from tests.integration.feature_repos.repo_configuration import ( - construct_test_environment, +from tests.integration.compute_engines.ray_compute.ray_shared_utils import ( + driver, + now, ) -from tests.integration.feature_repos.universal.online_store.redis import ( - RedisOnlineStoreCreator, -) - -now = datetime.now() -today = datetime.today() - -driver = Entity( - name="driver_id", - description="driver id", -) - - -def create_feature_dataset(ray_environment) -> DataSource: - yesterday = today - timedelta(days=1) - last_week = today - timedelta(days=7) - df = pd.DataFrame( - [ - { - "driver_id": 1001, - "event_timestamp": yesterday, - "created": now - timedelta(hours=2), - "conv_rate": 0.8, - "acc_rate": 0.5, - "avg_daily_trips": 15, - }, - { - "driver_id": 1001, - "event_timestamp": last_week, - "created": now - timedelta(hours=3), - "conv_rate": 0.75, - "acc_rate": 0.9, - "avg_daily_trips": 14, - }, - { - "driver_id": 1002, - "event_timestamp": yesterday, - "created": now - timedelta(hours=2), - "conv_rate": 0.7, - "acc_rate": 0.4, - "avg_daily_trips": 12, - }, - { - "driver_id": 1002, - "event_timestamp": yesterday - timedelta(days=1), - "created": now - timedelta(hours=2), - "conv_rate": 0.3, - "acc_rate": 0.6, - "avg_daily_trips": 12, - }, - ] - ) - ds = ray_environment.data_source_creator.create_data_source( - df, - ray_environment.feature_store.project, - timestamp_field="event_timestamp", - created_timestamp_column="created", - ) - return ds - - -def create_entity_df() -> pd.DataFrame: - entity_df = pd.DataFrame( - [ - {"driver_id": 1001, "event_timestamp": today}, - {"driver_id": 1002, "event_timestamp": today}, - ] - ) - return entity_df - - -def create_ray_environment(): - ray_config = IntegrationTestRepoConfig( - provider="local", - online_store_creator=RedisOnlineStoreCreator, - offline_store_creator=RayDataSourceCreator, - batch_engine={ - "type": "ray.engine", - "use_ray_cluster": False, - "max_workers": 2, - "enable_optimization": True, - }, - ) - ray_environment = construct_test_environment( - ray_config, None, entity_key_serialization_version=3 - ) - ray_environment.setup() - return ray_environment @pytest.mark.integration @pytest.mark.xdist_group(name="ray") -def test_ray_compute_engine_get_historical_features(): +def test_ray_compute_engine_get_historical_features( + ray_environment, feature_dataset, entity_df +): """Test Ray compute engine historical feature retrieval.""" - ray_environment = create_ray_environment() fs = ray_environment.feature_store registry = fs.registry - data_source = create_feature_dataset(ray_environment) def transform_feature(df: pd.DataFrame) -> pd.DataFrame: df["sum_conv_rate"] = df["sum_conv_rate"] * 2 @@ -154,55 +59,42 @@ def transform_feature(df: pd.DataFrame) -> pd.DataFrame: ], online=False, offline=False, - source=data_source, + source=feature_dataset, ) - entity_df = create_entity_df() - - try: - fs.apply([driver, driver_stats_fv]) - - # Build retrieval task - task = HistoricalRetrievalTask( - project=ray_environment.project, - entity_df=entity_df, - feature_view=driver_stats_fv, - full_feature_name=False, - registry=registry, - ) - - # Run RayComputeEngine - engine = RayComputeEngine( - repo_config=ray_environment.config, - offline_store=RayOfflineStore(), - online_store=MagicMock(), - ) + fs.apply([driver, driver_stats_fv]) - ray_dag_retrieval_job = engine.get_historical_features(registry, task) - ray_dataset = cast(RayDAGRetrievalJob, ray_dag_retrieval_job).to_ray_dataset() - df_out = ray_dataset.to_pandas().sort_values("driver_id") + # Build retrieval task + task = HistoricalRetrievalTask( + project=ray_environment.project, + entity_df=entity_df, + feature_view=driver_stats_fv, + full_feature_name=False, + registry=registry, + ) + engine = RayComputeEngine( + repo_config=ray_environment.config, + offline_store=RayOfflineStore(), + online_store=MagicMock(), + ) - # Assert output - assert df_out.driver_id.to_list() == [1001, 1002] - assert abs(df_out["sum_conv_rate"].to_list()[0] - 1.6) < 1e-6 - assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 - assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.0) < 1e-6 - assert abs(df_out["avg_acc_rate"].to_list()[1] - 1.0) < 1e-6 + ray_dag_retrieval_job = engine.get_historical_features(registry, task) + ray_dataset = cast(RayDAGRetrievalJob, ray_dag_retrieval_job).to_ray_dataset() + df_out = ray_dataset.to_pandas().sort_values("driver_id") - finally: - ray_environment.teardown() - if ray.is_initialized(): - ray.shutdown() + assert df_out.driver_id.to_list() == [1001, 1002] + assert abs(df_out["sum_conv_rate"].to_list()[0] - 1.6) < 1e-6 + assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[1] - 1.0) < 1e-6 @pytest.mark.integration @pytest.mark.xdist_group(name="ray") -def test_ray_compute_engine_materialize(): +def test_ray_compute_engine_materialize(ray_environment, feature_dataset): """Test Ray compute engine materialization.""" - ray_environment = create_ray_environment() fs = ray_environment.feature_store registry = fs.registry - data_source = create_feature_dataset(ray_environment) def transform_feature(df: pd.DataFrame) -> pd.DataFrame: df["sum_conv_rate"] = df["sum_conv_rate"] * 2 @@ -228,42 +120,32 @@ def transform_feature(df: pd.DataFrame) -> pd.DataFrame: ], online=True, offline=False, - source=data_source, + source=feature_dataset, ) def tqdm_builder(length): return tqdm(length, ncols=100) - try: - fs.apply([driver, driver_stats_fv]) - - # Build materialization task - task = MaterializationTask( - project=ray_environment.project, - feature_view=driver_stats_fv, - start_time=now - timedelta(days=2), - end_time=now, - tqdm_builder=tqdm_builder, - ) + fs.apply([driver, driver_stats_fv]) - # Run RayComputeEngine - engine = RayComputeEngine( - repo_config=ray_environment.config, - offline_store=RayOfflineStore(), - online_store=MagicMock(), - ) - - ray_materialize_jobs = engine.materialize(registry, task) + task = MaterializationTask( + project=ray_environment.project, + feature_view=driver_stats_fv, + start_time=now - timedelta(days=2), + end_time=now, + tqdm_builder=tqdm_builder, + ) - assert len(ray_materialize_jobs) == 1 - assert ray_materialize_jobs[0].status() == MaterializationJobStatus.SUCCEEDED + engine = RayComputeEngine( + repo_config=ray_environment.config, + offline_store=RayOfflineStore(), + online_store=MagicMock(), + ) - # Additional assertions can be added here for online store checks + ray_materialize_jobs = engine.materialize(registry, task) - finally: - ray_environment.teardown() - if ray.is_initialized(): - ray.shutdown() + assert len(ray_materialize_jobs) == 1 + assert ray_materialize_jobs[0].status() == MaterializationJobStatus.SUCCEEDED @pytest.mark.integration diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/test_source_feature_views.py b/sdk/python/tests/integration/compute_engines/ray_compute/test_source_feature_views.py new file mode 100644 index 00000000000..7d8f23e1bf6 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/ray_compute/test_source_feature_views.py @@ -0,0 +1,308 @@ +import time +from datetime import timedelta + +import pandas as pd +import pytest + +from feast import FeatureView, Field +from feast.data_source import DataSource +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, +) +from feast.types import Float32, Int32, Int64 +from tests.integration.compute_engines.ray_compute.ray_shared_utils import ( + create_entity_df, + create_feature_dataset, + create_unique_sink_source, + driver, + now, + today, +) + + +def create_base_feature_view(source: DataSource, name_suffix: str = "") -> FeatureView: + """Create a base feature view with data source.""" + return FeatureView( + name=f"base_driver_stats{name_suffix}", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=source, + ) + + +def create_derived_feature_view( + base_fv: FeatureView, sink_source: DataSource, name_suffix: str = "" +) -> FeatureView: + """Create a derived feature view that uses another feature view as source. + Note: This creates a regular FeatureView with another FeatureView as source. + """ + return FeatureView( + name=f"derived_driver_stats{name_suffix}", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), # Same feature names as source + Field(name="acc_rate", dtype=Float32), # Same feature names as source + Field(name="avg_daily_trips", dtype=Int64), # Same feature names as source + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=base_fv, + sink_source=sink_source, + ) + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_single_source_feature_view(ray_environment, temp_dir): + """Test Ray compute engine with a single source feature view.""" + fs = ray_environment.feature_store + + data_source = create_feature_dataset(ray_environment) + base_fv = create_base_feature_view(data_source, "_single") + sink_source = create_unique_sink_source(temp_dir, "derived_sink_single") + derived_fv = create_derived_feature_view(base_fv, sink_source, "_single") + fs.apply([driver, base_fv, derived_fv]) + + entity_df = create_entity_df() + job = fs.get_historical_features( + entity_df=entity_df, + features=[ + f"{base_fv.name}:conv_rate", + f"{base_fv.name}:acc_rate", + f"{derived_fv.name}:conv_rate", + f"{derived_fv.name}:acc_rate", + ], + full_feature_names=True, + ) + result_df = job.to_df() + assert len(result_df) == 2 + assert f"{base_fv.name}__conv_rate" in result_df.columns + assert f"{base_fv.name}__acc_rate" in result_df.columns + assert f"{derived_fv.name}__conv_rate" in result_df.columns + assert f"{derived_fv.name}__acc_rate" in result_df.columns + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_materialization_with_source_feature_views( + ray_environment, temp_dir +): + """Test Ray compute engine materialization with source feature views.""" + fs = ray_environment.feature_store + data_source = create_feature_dataset(ray_environment) + base_fv = create_base_feature_view(data_source, "_materialize") + + sink_source = create_unique_sink_source(temp_dir, "derived_sink") + derived_fv = create_derived_feature_view(base_fv, sink_source, "_materialize") + + fs.apply([driver, base_fv, derived_fv]) + start_date = today - timedelta(days=7) + end_date = today + + # Materialize only the derived feature view - compute engine handles base dependencies + derived_job = fs.materialize( + start_date=start_date, + end_date=end_date, + feature_views=[derived_fv.name], + ) + + if derived_job is not None: + assert derived_job.status == MaterializationJobStatus.SUCCEEDED + else: + print("Materialization completed synchronously (no job object returned)") + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_cycle_detection(ray_environment, temp_dir): + """Test Ray compute engine cycle detection in feature view dependencies.""" + fs = ray_environment.feature_store + data_source = create_feature_dataset(ray_environment) + sink_source1 = create_unique_sink_source(temp_dir, "cycle_sink1") + sink_source2 = create_unique_sink_source(temp_dir, "cycle_sink2") + + fv1 = FeatureView( + name="cycle_fv1", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=data_source, + online=True, + offline=True, + ) + + fv2 = FeatureView( + name="cycle_fv2", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=fv1, + sink_source=sink_source1, + online=True, + offline=True, + ) + + fv3 = FeatureView( + name="cycle_fv3", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=fv2, + sink_source=sink_source2, + online=True, + offline=True, + ) + + # Apply feature views (this should work without cycles) + fs.apply([driver, fv1, fv2, fv3]) + + entity_df = create_entity_df() + + job = fs.get_historical_features( + entity_df=entity_df, + features=[ + f"{fv1.name}:conv_rate", + f"{fv2.name}:conv_rate", + f"{fv3.name}:conv_rate", + ], + full_feature_names=True, + ) + + result_df = job.to_df() + + assert len(result_df) == 2 + assert f"{fv1.name}__conv_rate" in result_df.columns + assert f"{fv2.name}__conv_rate" in result_df.columns + assert f"{fv3.name}__conv_rate" in result_df.columns + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_error_handling(ray_environment, temp_dir): + """Test Ray compute engine error handling with invalid source feature views.""" + fs = ray_environment.feature_store + data_source = create_feature_dataset(ray_environment) + base_fv = create_base_feature_view(data_source, "_error") + + # Test 1: Regular FeatureView with FeatureView source but no sink_source should fail + with pytest.raises( + ValueError, match="Derived FeatureView must specify `sink_source`" + ): + FeatureView( + name="invalid_fv", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], + source=base_fv, + online=True, + offline=True, + ) + + # Test 2: Valid FeatureView with sink_source should work + sink_source = create_unique_sink_source(temp_dir, "valid_sink") + valid_fv = FeatureView( + name="valid_fv", + entities=[driver], + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="driver_id", dtype=Int32), + ], # Use same feature name as source + source=base_fv, + sink_source=sink_source, + online=True, + offline=True, + ) + + fs.apply([driver, base_fv, valid_fv]) + entity_df = create_entity_df() + job = fs.get_historical_features( + entity_df=entity_df, + features=[ + f"{base_fv.name}:conv_rate", + f"{valid_fv.name}:conv_rate", # Use same feature name as source + ], + full_feature_names=True, + ) + + result_df = job.to_df() + assert len(result_df) == 2 + assert f"{base_fv.name}__conv_rate" in result_df.columns + assert f"{valid_fv.name}__conv_rate" in result_df.columns + assert result_df[f"{base_fv.name}__conv_rate"].notna().all() + assert result_df[f"{valid_fv.name}__conv_rate"].notna().all() + + +@pytest.mark.integration +@pytest.mark.xdist_group(name="ray") +def test_ray_compute_engine_performance_with_source_feature_views( + ray_environment, temp_dir +): + """Test Ray compute engine performance with source feature views.""" + fs = ray_environment.feature_store + large_df = pd.DataFrame() + for i in range(1000): + large_df = pd.concat( + [ + large_df, + pd.DataFrame( + { + "driver_id": [1000 + i], + "event_timestamp": [today - timedelta(days=i % 30)], + "created": [now - timedelta(hours=i % 24)], + "conv_rate": [0.5 + (i % 10) * 0.05], + "acc_rate": [0.6 + (i % 10) * 0.04], + "avg_daily_trips": [10 + i % 20], + } + ), + ] + ) + data_source = ray_environment.data_source_creator.create_data_source( + large_df, + fs.project, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + base_fv = create_base_feature_view(data_source, "_perf") + sink_source1 = create_unique_sink_source(temp_dir, "perf_sink1") + derived_fv1 = create_derived_feature_view(base_fv, sink_source1, "_perf1") + sink_source2 = create_unique_sink_source(temp_dir, "perf_sink2") + derived_fv2 = create_derived_feature_view(derived_fv1, sink_source2, "_perf2") + fs.apply([driver, base_fv, derived_fv1, derived_fv2]) + + large_entity_df = pd.DataFrame( + { + "driver_id": [1000 + i for i in range(100)], + "event_timestamp": [today] * 100, + } + ) + start_time = time.time() + job = fs.get_historical_features( + entity_df=large_entity_df, + features=[ + f"{base_fv.name}:conv_rate", + f"{derived_fv1.name}:conv_rate", + ], + full_feature_names=True, + ) + result_df = job.to_df() + end_time = time.time() + assert len(result_df) == 100 + assert f"{base_fv.name}__conv_rate" in result_df.columns + assert f"{derived_fv1.name}__conv_rate" in result_df.columns + assert end_time - start_time < 60 diff --git a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py index 0cd3f7ca4a0..c6cfc13280d 100644 --- a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py @@ -17,7 +17,6 @@ RayJoinNode, RayReadNode, RayTransformationNode, - RayWriteNode, ) @@ -293,30 +292,6 @@ def test_ray_dedup_node( assert "driver_id" in result_df.columns -def test_ray_write_node( - ray_session, ray_config, mock_context, sample_data, column_info -): - """Test RayWriteNode functionality.""" - ray_dataset = ray.data.from_pandas(sample_data) - input_value = DAGValue(data=ray_dataset, format=DAGFormat.RAY) - dummy_node = DummyInputNode("input_node", input_value) - - mock_feature_view = DummyFeatureView() - node = RayWriteNode( - name="write", - feature_view=mock_feature_view, - config=ray_config, - ) - node.add_input(dummy_node) - mock_context.node_outputs = {"input_node": input_value} - result = node.execute(mock_context) - assert isinstance(result, DAGValue) - assert result.format == DAGFormat.RAY - result_df = result.data.to_pandas() - assert len(result_df) == 3 - assert "driver_id" in result_df.columns - - def test_ray_config_validation(): """Test Ray configuration validation.""" # Test valid configuration From c3a87ab03a38df7aac9a6c85e83fbde82d5ad256 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Sun, 27 Jul 2025 11:39:30 +0530 Subject: [PATCH 09/10] fix: Remove redundant source resolution Signed-off-by: ntkathole --- docs/reference/compute-engine/ray.md | 10 --- docs/reference/offline-stores/ray.md | 38 +++++----- pyproject.toml | 2 +- .../compute_engines/ray/feature_builder.py | 49 +++++------- .../feast/infra/compute_engines/ray/nodes.py | 76 ++++++++----------- setup.py | 1 - 6 files changed, 71 insertions(+), 105 deletions(-) diff --git a/docs/reference/compute-engine/ray.md b/docs/reference/compute-engine/ray.md index a286867cd5f..01ff9c0dd34 100644 --- a/docs/reference/compute-engine/ray.md +++ b/docs/reference/compute-engine/ray.md @@ -203,16 +203,6 @@ batch_engine: # Ray cluster configuration (inherits from offline_store if not specified) ray_address: localhost:10001 # Ray cluster address use_ray_cluster: true # Use Ray cluster mode - -# Optional: Online store configuration -online_store: - type: sqlite - path: data/online_store.db - -# Optional: Feature server configuration -feature_server: - port: 6566 - metrics_port: 8888 ``` ## DAG Node Types diff --git a/docs/reference/offline-stores/ray.md b/docs/reference/offline-stores/ray.md index fc7baed5965..65486312096 100644 --- a/docs/reference/offline-stores/ray.md +++ b/docs/reference/offline-stores/ray.md @@ -187,17 +187,7 @@ batch_engine: #### Ray Compute Engine Options -| Option | Type | Default | Description | -|--------|------|---------|-------------| -| `type` | string | Required | Must be `ray.engine` | -| `max_workers` | int | CPU count | Maximum number of Ray workers | -| `enable_optimization` | boolean | true | Enable performance optimizations | -| `broadcast_join_threshold_mb` | int | 100 | Size threshold for broadcast joins (MB) | -| `max_parallelism_multiplier` | int | 2 | Parallelism as multiple of CPU cores | -| `target_partition_size_mb` | int | 64 | Target partition size (MB) | -| `window_size_for_joins` | string | "1H" | Time window for distributed joins | -| `enable_distributed_joins` | boolean | true | Enable distributed joins for large datasets | -| `staging_location` | string | None | Remote path for batch materialization jobs | +For Ray compute engine configuration options, see the [Ray Compute Engine documentation](../compute-engine/ray.md#configuration-options). ## Resource Management and Testing @@ -448,20 +438,12 @@ The Ray offline store has the following limitations: 1. **File Sources Only**: Currently supports only `FileSource` data sources 2. **No Direct SQL**: Does not support SQL query interfaces 3. **No Online Writes**: Cannot write directly to online stores -4. **Limited Transformations**: Complex feature transformations should use the Ray Compute Engine +4. **No Complex Transformations**: The Ray offline store focuses on data I/O operations. For complex feature transformations (aggregations, joins, custom UDFs), use the [Ray Compute Engine](../compute-engine/ray.md) instead ## Integration with Ray Compute Engine For complex feature processing operations, use the Ray offline store in combination with the [Ray Compute Engine](../compute-engine/ray.md). See the **Ray Offline Store + Compute Engine** configuration example in the [Configuration](#configuration) section above for a complete setup. -The Ray offline store provides the data I/O foundation, while the Ray compute engine handles: -- **Point-in-time joins**: Efficient temporal joins for historical feature retrieval -- **Feature aggregations**: Distributed aggregations across time windows -- **Complex transformations**: Advanced feature transformations and computations -- **Historical feature retrieval**: `get_historical_features()` with distributed processing -- **Distributed processing optimization**: Automatic join strategy selection and resource management -- **Materialization**: Distributed batch materialization with progress tracking - For more advanced troubleshooting, refer to the [Ray documentation](https://docs.ray.io/en/latest/data/getting-started.html). @@ -507,6 +489,22 @@ features = store.get_historical_features(entity_df=df, features=["fv:feature"]) # Direct data access (uses offline store) job = RayOfflineStore.pull_latest_from_table_or_query(...) df = job.to_df() + +# Offline write batch (materialization) +# Create sample data for materialization +data = pa.table({ + "driver_id": [1, 2, 3, 4, 5], + "avg_daily_trips": [10.5, 15.2, 8.7, 12.3, 9.8], + "event_timestamp": [datetime.now()] * 5 +}) + +# Write batch to offline store +RayOfflineStore.offline_write_batch( + config=store.config, + feature_view=driver_stats_fv, + table=data, + progress=lambda rows: print(f"Processed {rows} rows") +) ``` For complete examples, see the [Configuration](#configuration) section above. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2879bfbba26..357a20fcf17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,7 +170,7 @@ ci = [ "virtualenv<20.24.2", "feast[aws, azure, cassandra, clickhouse, couchbase, delta, docling, duckdb, elasticsearch, faiss, gcp, ge, go, grpcio, hazelcast, hbase, ibis, ikv, k8s, mcp, milvus, mssql, mysql, opentelemetry, spark, trino, postgres, pytorch, qdrant, rag, ray, redis, singlestore, snowflake, sqlite_vec]" ] -nlp = ["feast[docling, milvus, pytorch, rag, ray]"] +nlp = ["feast[docling, milvus, pytorch, rag]"] dev = ["feast[ci]"] docs = ["feast[ci]"] # used for the 'feature-server' container image build diff --git a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py index 8d6003ff3c8..03a868c1779 100644 --- a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py @@ -1,7 +1,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast -from feast.feature_view_utils import resolve_feature_view_source_with_fallback +from feast import FeatureView from feast.infra.common.materialization_job import MaterializationTask from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.algorithms.topo import topological_sort @@ -47,9 +47,8 @@ def __init__( def build_source_node(self, view): """Build the source node for reading feature data.""" - - source_info = resolve_feature_view_source_with_fallback( - view, config=None, is_materialization=self.is_materialization + data_source = getattr(view, "batch_source", None) or getattr( + view, "source", None ) start_time = self.task.start_time end_time = self.task.end_time @@ -57,7 +56,7 @@ def build_source_node(self, view): node = RayReadNode( name="source", - source=source_info.data_source, + source=data_source, column_info=column_info, config=self.config, start_time=start_time, @@ -173,47 +172,36 @@ def build_validation_node(self, view, input_node): return input_node def _build(self, view, input_nodes: Optional[List[DAGNode]]) -> DAGNode: - """Override _build to handle derived views during materialization.""" - # Step 1: build source node - if view.data_source: - last_node = self.build_source_node(view) + has_physical_source = (hasattr(view, "batch_source") and view.batch_source) or ( + hasattr(view, "source") + and view.source + and not isinstance(view.source, FeatureView) + ) - # If source node is None (derived view during materialization), use input nodes - if last_node is None and input_nodes: - if self._should_transform(view): - last_node = self.build_transformation_node(view, input_nodes) - else: - last_node = self.build_join_node(view, input_nodes) - elif last_node is not None: - if self._should_transform(view): - # Transform applied to the source data - last_node = self.build_transformation_node(view, [last_node]) + is_derived_view = hasattr(view, "source_views") and view.source_views - # If there are input nodes, transform or join them + if has_physical_source and not is_derived_view: + last_node = self.build_source_node(view) + if self._should_transform(view): + last_node = self.build_transformation_node(view, [last_node]) elif input_nodes: - # User-defined transform handles the merging of input views if self._should_transform(view): last_node = self.build_transformation_node(view, input_nodes) - # Default join else: last_node = self.build_join_node(view, input_nodes) else: raise ValueError(f"FeatureView {view.name} has no valid source or inputs") - # Skip subsequent steps if last_node is None if last_node is None: raise ValueError(f"Failed to build processing node for {view.name}") - # Step 2: filter last_node = self.build_filter_node(view, last_node) - # Step 3: aggregate or dedupe if self._should_aggregate(view): last_node = self.build_aggregation_node(view, last_node) elif self._should_dedupe(view): last_node = self.build_dedup_node(view, last_node) - # Step 4: validate if self._should_validate(view): last_node = self.build_validation_node(view, last_node) @@ -260,9 +248,7 @@ def _build_materialization_plan(self) -> ExecutionPlan: parent_write_nodes.append(view_to_write_node[parent.view.name]) if parent_write_nodes: - # For derived views, create a simple passthrough node that depends on parents - # This ensures the derived view processing only starts after parents are materialized - processing_node = RayDerivedReadNode( + derived_read_node = RayDerivedReadNode( name=f"{view.name}:derived_read", feature_view=view, parent_dependencies=cast(List[DAGNode], parent_write_nodes), @@ -270,7 +256,10 @@ def _build_materialization_plan(self) -> ExecutionPlan: column_info=self.get_column_info(view), is_materialization=self.is_materialization, ) - self.nodes.append(processing_node) + self.nodes.append(derived_read_node) + + # Then build the rest of the processing chain (filter, aggregate, etc.) + processing_node = self._build(view, [derived_read_node]) else: # Parent not yet built - this shouldn't happen in topo order raise ValueError(f"Parent views for {view.name} not yet built") diff --git a/sdk/python/feast/infra/compute_engines/ray/nodes.py b/sdk/python/feast/infra/compute_engines/ray/nodes.py index 2d919da7274..c7b3ad701ae 100644 --- a/sdk/python/feast/infra/compute_engines/ray/nodes.py +++ b/sdk/python/feast/infra/compute_engines/ray/nodes.py @@ -11,7 +11,7 @@ from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.aggregation import Aggregation from feast.data_source import DataSource -from feast.feature_view_utils import resolve_feature_view_source_with_fallback +from feast.feature_view_utils import get_transformation_function, has_transformation from feast.infra.common.serde import SerializedArtifacts from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat @@ -597,53 +597,43 @@ def __init__( def execute(self, context: ExecutionContext) -> DAGValue: """Execute the derived read operation after parents are materialized.""" - # Wait for all parent dependencies to complete - # The inputs contain the parent write nodes which have completed materialization - self.get_input_values(context) - source_info = resolve_feature_view_source_with_fallback( - self.feature_view, config=None, is_materialization=self.is_materialization - ) - data_source = source_info.data_source - try: - retrieval_job = create_offline_store_retrieval_job( - data_source=data_source, - column_info=self.column_info, - context=context, - start_time=None, - end_time=None, + parent_values = self.get_input_values(context) + + if not parent_values: + raise ValueError( + f"No parent data available for derived view {self.feature_view.name}" ) - # Convert to Ray Dataset - if hasattr(retrieval_job, "to_ray_dataset"): - ray_dataset = retrieval_job.to_ray_dataset() - else: - try: - arrow_table = retrieval_job.to_arrow() - ray_dataset = ray.data.from_arrow(arrow_table) - except Exception: - df = retrieval_job.to_df() - ray_dataset = ray.data.from_pandas(df) + parent_value = parent_values[0] + parent_value.assert_format(DAGFormat.RAY) - # Apply field mapping if needed - field_mapping = getattr(data_source, "field_mapping", None) - if field_mapping: - ray_dataset = apply_field_mapping(ray_dataset, field_mapping) + if has_transformation(self.feature_view): + transformation_func = get_transformation_function(self.feature_view) + if callable(transformation_func): - return DAGValue( - data=ray_dataset, - format=DAGFormat.RAY, - metadata={ - "source": "derived_from_parent", - "source_description": source_info.source_description, - "data_source_path": getattr(data_source, "path", "unknown"), - }, - ) + def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame: + return transformation_func(batch) - except Exception as e: - logger.error( - f"Failed to read derived view {self.feature_view.name} from parent data source {getattr(data_source, 'path', 'unknown')}: {e}" - ) - raise + transformed_dataset = parent_value.data.map_batches( + apply_transformation + ) + return DAGValue( + data=transformed_dataset, + format=DAGFormat.RAY, + metadata={ + "source": "derived_from_parent", + "source_description": f"Transformed data from parent for {self.feature_view.name}", + }, + ) + + return DAGValue( + data=parent_value.data, + format=DAGFormat.RAY, + metadata={ + "source": "derived_from_parent", + "source_description": f"Data from parent for {self.feature_view.name}", + }, + ) class RayWriteNode(DAGNode): diff --git a/setup.py b/setup.py index 2dcaea178cb..7545b0c19ae 100644 --- a/setup.py +++ b/setup.py @@ -279,7 +279,6 @@ + MILVUS_REQUIRED + TORCH_REQUIRED + RAG_REQUIRED - + RAY_REQUIRED ) DOCS_REQUIRED = CI_REQUIRED DEV_REQUIRED = CI_REQUIRED From 5435859709c651150b4f680d6f569fb2495b2fb4 Mon Sep 17 00:00:00 2001 From: ntkathole Date: Wed, 6 Aug 2025 12:05:11 +0530 Subject: [PATCH 10/10] fix: Handle comments and removed use_ray_cluster config Signed-off-by: ntkathole --- docs/reference/compute-engine/ray.md | 40 +++++------ docs/reference/offline-stores/ray.md | 11 ---- .../infra/compute_engines/ray/compute.py | 8 +-- .../feast/infra/compute_engines/ray/config.py | 3 - .../compute_engines/ray/feature_builder.py | 3 + .../feast/infra/compute_engines/ray/job.py | 1 + .../feast/infra/compute_engines/ray/nodes.py | 32 +++------ .../contrib/ray_offline_store/ray.py | 66 ++++--------------- .../contrib/ray_repo_configuration.py | 1 - sdk/python/feast/infra/ray_shared_utils.py | 1 - sdk/python/feast/type_map.py | 47 +++++++++++++ .../ray_compute/repo_configuration.py | 1 - .../ray_compute/test_compute.py | 2 - .../compute_engines/ray_compute/test_nodes.py | 2 - 14 files changed, 92 insertions(+), 126 deletions(-) diff --git a/docs/reference/compute-engine/ray.md b/docs/reference/compute-engine/ray.md index 01ff9c0dd34..4ecc449e40b 100644 --- a/docs/reference/compute-engine/ray.md +++ b/docs/reference/compute-engine/ray.md @@ -42,29 +42,31 @@ offline_store: storage_path: data/ray_storage batch_engine: type: ray.engine - max_workers: 4 # Optional: Maximum number of workers - enable_optimization: true # Optional: Enable performance optimizations + max_workers: 4 # Optional: Maximum number of workers + enable_optimization: true # Optional: Enable performance optimizations broadcast_join_threshold_mb: 100 # Optional: Broadcast join threshold (MB) max_parallelism_multiplier: 2 # Optional: Parallelism multiplier target_partition_size_mb: 64 # Optional: Target partition size (MB) window_size_for_joins: "1H" # Optional: Time window for distributed joins ray_address: localhost:10001 # Optional: Ray cluster address - use_ray_cluster: false # Optional: Use Ray cluster mode ``` ### Configuration Options | Option | Type | Default | Description | |--------|------|---------|-------------| -| `type` | string | Required | Must be `ray.engine` | -| `max_workers` | int | CPU count | Maximum number of Ray workers | +| `type` | string | `"ray.engine"` | Must be `ray.engine` | +| `max_workers` | int | None (uses all cores) | Maximum number of Ray workers | | `enable_optimization` | boolean | true | Enable performance optimizations | | `broadcast_join_threshold_mb` | int | 100 | Size threshold for broadcast joins (MB) | | `max_parallelism_multiplier` | int | 2 | Parallelism as multiple of CPU cores | | `target_partition_size_mb` | int | 64 | Target partition size (MB) | | `window_size_for_joins` | string | "1H" | Time window for distributed joins | -| `ray_address` | string | None | Ray cluster address | -| `use_ray_cluster` | boolean | false | Use Ray cluster mode | +| `ray_address` | string | None | Ray cluster address (None = local Ray) | +| `enable_distributed_joins` | boolean | true | Enable distributed joins for large datasets | +| `staging_location` | string | None | Remote path for batch materialization jobs | +| `ray_conf` | dict | None | Ray configuration parameters | +| `execution_timeout_seconds` | int | None | Timeout for job execution in seconds | ## Usage Examples @@ -159,7 +161,6 @@ batch_engine: window_size_for_joins: "30min" # Ray cluster configuration - use_ray_cluster: true ray_address: "ray://head-node:10001" ``` @@ -181,7 +182,6 @@ offline_store: type: ray storage_path: s3://my-bucket/feast-data # Optional: Path for storing datasets ray_address: localhost:10001 # Optional: Ray cluster address - use_ray_cluster: true # Optional: Use Ray cluster mode # Ray compute engine configuration # Handles complex feature computation and distributed processing @@ -202,7 +202,6 @@ batch_engine: # Ray cluster configuration (inherits from offline_store if not specified) ray_address: localhost:10001 # Ray cluster address - use_ray_cluster: true # Use Ray cluster mode ``` ## DAG Node Types @@ -258,22 +257,18 @@ The Ray compute engine automatically selects optimal join strategies: ### Broadcast Join Used for small feature datasets: -```python -# Automatically selected when feature data < 100MB -# Features are cached in Ray's object store -# Entities are distributed across cluster -# Each worker gets a copy of feature data -``` +- Automatically selected when feature data < 100MB +- Features are cached in Ray's object store +- Entities are distributed across cluster +- Each worker gets a copy of feature data ### Distributed Windowed Join Used for large feature datasets: -```python -# Automatically selected when feature data > 100MB -# Data is partitioned by time windows -# Point-in-time joins within each window -# Results are combined across windows -``` +- Automatically selected when feature data > 100MB +- Data is partitioned by time windows +- Point-in-time joins within each window +- Results are combined across windows ### Strategy Selection Logic @@ -353,7 +348,6 @@ offline_store: storage_path: s3://my-bucket/feast-data batch_engine: type: ray.engine - use_ray_cluster: true ray_address: "ray://ray-cluster:10001" broadcast_join_threshold_mb: 50 ``` diff --git a/docs/reference/offline-stores/ray.md b/docs/reference/offline-stores/ray.md index 65486312096..58f62c34ece 100644 --- a/docs/reference/offline-stores/ray.md +++ b/docs/reference/offline-stores/ray.md @@ -33,7 +33,6 @@ The Ray offline store provides: | export to arrow table | Yes | | persist results in offline store| Yes | | local execution of ODFVs | Yes | -| remote execution of ODFVs | No | | preview query plan | Yes | | read partitioned data | Yes | @@ -74,7 +73,6 @@ offline_store: type: ray storage_path: data/ray_storage # Optional: Path for storing datasets ray_address: localhost:10001 # Optional: Ray cluster address - use_ray_cluster: false # Optional: Whether to use Ray cluster ``` ### Ray Offline Store + Compute Engine @@ -91,7 +89,6 @@ offline_store: type: ray storage_path: s3://my-bucket/feast-data # Optional: Path for storing datasets ray_address: localhost:10001 # Optional: Ray cluster address - use_ray_cluster: true # Optional: Use Ray cluster mode # Ray compute engine for distributed feature processing batch_engine: @@ -112,7 +109,6 @@ batch_engine: # Ray cluster configuration (optional) ray_address: localhost:10001 # Ray cluster address - use_ray_cluster: true # Use Ray cluster mode staging_location: s3://my-bucket/staging # Remote staging location ``` @@ -158,7 +154,6 @@ offline_store: type: ray storage_path: s3://my-production-bucket/feast-data ray_address: "ray://production-head-node:10001" - use_ray_cluster: true batch_engine: type: ray.engine @@ -169,7 +164,6 @@ batch_engine: target_partition_size_mb: 128 window_size_for_joins: "30min" ray_address: "ray://production-head-node:10001" - use_ray_cluster: true staging_location: s3://my-production-bucket/staging ``` @@ -182,7 +176,6 @@ batch_engine: | `type` | string | Required | Must be `feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore` or `ray` | | `storage_path` | string | None | Path for storing temporary files and datasets | | `ray_address` | string | None | Address of the Ray cluster (e.g., "localhost:10001") | -| `use_ray_cluster` | boolean | false | Whether to use Ray cluster mode | | `ray_conf` | dict | None | Ray initialization parameters for resource management (e.g., memory, CPU limits) | #### Ray Compute Engine Options @@ -224,7 +217,6 @@ offline_store: type: ray storage_path: s3://my-bucket/feast-data ray_address: "ray://production-cluster:10001" - use_ray_cluster: true # Optimized for production workloads broadcast_join_threshold_mb: 100 max_parallelism_multiplier: 2 @@ -259,7 +251,6 @@ offline_store: offline_store: type: ray ray_address: "ray://cluster-head:10001" - use_ray_cluster: true broadcast_join_threshold_mb: 200 max_parallelism_multiplier: 4 ``` @@ -406,7 +397,6 @@ ray start --head --port=10001 offline_store: type: ray ray_address: localhost:10001 - use_ray_cluster: true storage_path: s3://my-bucket/features ``` @@ -468,7 +458,6 @@ offline_store: offline_store: type: ray storage_path: s3://my-bucket/feast-data - use_ray_cluster: true batch_engine: type: ray.engine diff --git a/sdk/python/feast/infra/compute_engines/ray/compute.py b/sdk/python/feast/infra/compute_engines/ray/compute.py index 0cd7cddccfd..7bf7e15dfb0 100644 --- a/sdk/python/feast/infra/compute_engines/ray/compute.py +++ b/sdk/python/feast/infra/compute_engines/ray/compute.py @@ -58,7 +58,7 @@ def __init__( def _ensure_ray_initialized(self): """Ensure Ray is initialized with proper configuration.""" if not ray.is_initialized(): - if self.config.use_ray_cluster and self.config.ray_address: + if self.config.ray_address: ray.init( address=self.config.ray_address, ignore_reinit_error=True, @@ -206,7 +206,7 @@ def _materialize_from_offline_store( if getattr(feature_view, "online", False): # TODO: Implement proper online store writing with correct data format conversion logger.debug( - f"Online store writing not implemented yet for {arrow_table.num_rows} rows" + "Online store writing not implemented yet for Ray compute engine" ) # Write to offline store if enabled (this handles sink_source automatically for derived views) @@ -228,9 +228,7 @@ def _materialize_from_offline_store( # Write to sink_source using Ray data try: - # Convert arrow table to pandas then to ray dataset - df = arrow_table.to_pandas() - ray_dataset = ray.data.from_pandas(df) + ray_dataset = ray.data.from_arrow(arrow_table) ray_dataset.write_parquet(sink_source.path) except Exception as e: logger.error( diff --git a/sdk/python/feast/infra/compute_engines/ray/config.py b/sdk/python/feast/infra/compute_engines/ray/config.py index 0e25320651f..c6d74d262dd 100644 --- a/sdk/python/feast/infra/compute_engines/ray/config.py +++ b/sdk/python/feast/infra/compute_engines/ray/config.py @@ -17,9 +17,6 @@ class RayComputeEngineConfig(FeastConfigBaseModel): ray_address: Optional[str] = None """Ray cluster address. If None, uses local Ray cluster.""" - use_ray_cluster: bool = False - """Whether to use an existing Ray cluster.""" - staging_location: Optional[StrictStr] = None """Remote path for batch materialization jobs""" diff --git a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py index 03a868c1779..07c5c6f1113 100644 --- a/sdk/python/feast/infra/compute_engines/ray/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/ray/feature_builder.py @@ -169,6 +169,9 @@ def build_output_nodes(self, view, final_node): def build_validation_node(self, view, input_node): """Build the validation node for feature validation.""" # TODO: Implement validation logic + logger.warning( + "Feature validation is not yet implemented for Ray compute engine." + ) return input_node def _build(self, view, input_nodes: Optional[List[DAGNode]]) -> DAGNode: diff --git a/sdk/python/feast/infra/compute_engines/ray/job.py b/sdk/python/feast/infra/compute_engines/ray/job.py index bfc0943ca95..b2e88f1d5c5 100644 --- a/sdk/python/feast/infra/compute_engines/ray/job.py +++ b/sdk/python/feast/infra/compute_engines/ray/job.py @@ -184,6 +184,7 @@ def to_remote_storage(self) -> list[str]: f"{self._config.batch_engine.staging_location}/{str(uuid.uuid4())}" ) self._result_dataset.write_parquet(output_uri) + logger.debug(f"Wrote result to {output_uri}") return [output_uri] except Exception as e: raise RuntimeError(f"Failed to write to remote storage: {e}") diff --git a/sdk/python/feast/infra/compute_engines/ray/nodes.py b/sdk/python/feast/infra/compute_engines/ray/nodes.py index c7b3ad701ae..5a5f04acee3 100644 --- a/sdk/python/feast/infra/compute_engines/ray/nodes.py +++ b/sdk/python/feast/infra/compute_engines/ray/nodes.py @@ -136,25 +136,11 @@ def execute(self, context: ExecutionContext) -> DAGValue: # Check if the feature dataset contains aggregated features (from aggregation node) # If so, we don't need point-in-time join logic - just simple join on entity keys - sample_data = feature_dataset.take(1) - is_aggregated = False - if sample_data: - if hasattr(sample_data[0], "columns"): - feature_cols = sample_data[0].columns.tolist() - else: - feature_cols = ( - list(sample_data[0].keys()) - if isinstance(sample_data[0], dict) - else [] - ) - - # Check for aggregated feature column patterns - is_aggregated = any( - col.startswith( - ("sum_", "avg_", "mean_", "count_", "min_", "max_", "std_", "var_") - ) - for col in feature_cols - ) + is_aggregated = ( + input_value.metadata.get("aggregated", False) + if input_value.metadata + else False + ) feature_size = feature_dataset.size_bytes() @@ -367,8 +353,7 @@ def execute(self, context: ExecutionContext) -> DAGValue: elif agg.function == "var": agg_dict[feature_name] = (agg.column, "var") else: - logger.warning(f"Unknown aggregation function: {agg.function}") - continue + raise ValueError(f"Unknown aggregation function: {agg.function}.") # Apply aggregations using pandas fallback (Ray's native groupby has compatibility issues) if self.group_by_keys and agg_dict: @@ -422,8 +407,7 @@ def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Data elif function == "var": result = grouped[column].var() else: - logger.warning(f"Unknown aggregation function: {function}") - continue + raise ValueError(f"Unknown aggregation function: {function}.") result.name = feature_name agg_results.append(result) @@ -680,7 +664,7 @@ def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame: if getattr(feature_view, "online", False): # TODO: Implement proper online store writing with correct data format conversion logger.debug( - f"Online store writing not implemented yet for {len(batch)} rows" + "Online store writing not implemented yet for Ray compute engine" ) # Write to offline store if enabled diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index f16766b7d17..8a82ec24a64 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -48,9 +48,13 @@ from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage, ValidationReference -from feast.type_map import feast_value_type_to_pandas_type, pa_to_feast_value_type +from feast.type_map import ( + convert_array_column, + convert_scalar_column, + feast_value_type_to_pandas_type, + pa_to_feast_value_type, +) from feast.utils import _get_column_names, make_df_tzaware, make_tzaware -from feast.value_type import ValueType logger = logging.getLogger(__name__) @@ -284,12 +288,12 @@ def convert_batch(batch: pd.DataFrame) -> pd.DataFrame: try: value_type = feature.dtype.to_value_type() if value_type.name.endswith("_LIST"): - batch[feat_name] = _convert_array_column( + batch[feat_name] = convert_array_column( batch[feat_name], value_type ) else: target_pandas_type = feast_value_type_to_pandas_type(value_type) - batch[feat_name] = _convert_scalar_column( + batch[feat_name] = convert_scalar_column( batch[feat_name], value_type, target_pandas_type ) except Exception as e: @@ -302,56 +306,12 @@ def convert_batch(batch: pd.DataFrame) -> pd.DataFrame: return _apply_to_data(data, convert_batch) -def _convert_scalar_column( - series: pd.Series, value_type: ValueType, target_pandas_type: str -) -> pd.Series: - """Convert a scalar feature column to the appropriate pandas type.""" - if value_type == ValueType.INT32: - return pd.to_numeric(series, errors="coerce").astype("Int32") - elif value_type == ValueType.INT64: - return pd.to_numeric(series, errors="coerce").astype("Int64") - elif value_type in [ValueType.FLOAT, ValueType.DOUBLE]: - return pd.to_numeric(series, errors="coerce").astype("float64") - elif value_type == ValueType.BOOL: - return series.astype("boolean") - elif value_type == ValueType.STRING: - return series.astype("string") - elif value_type == ValueType.UNIX_TIMESTAMP: - return pd.to_datetime(series, unit="s", errors="coerce") - else: - return series.astype(target_pandas_type) - - -def _convert_array_column(series: pd.Series, value_type: ValueType) -> pd.Series: - """Convert an array feature column to the appropriate type with proper empty array handling.""" - base_type_map = { - ValueType.INT32_LIST: np.int32, - ValueType.INT64_LIST: np.int64, - ValueType.FLOAT_LIST: np.float32, - ValueType.DOUBLE_LIST: np.float64, - ValueType.BOOL_LIST: np.bool_, - ValueType.STRING_LIST: object, - ValueType.BYTES_LIST: object, - ValueType.UNIX_TIMESTAMP_LIST: "datetime64[s]", - } - - target_dtype = base_type_map.get(value_type, object) - - def convert_array_item(item) -> Union[np.ndarray, Any]: - if item is None or (isinstance(item, list) and len(item) == 0): - if target_dtype == object: - return np.empty(0, dtype=object) - else: - return np.empty(0, dtype=target_dtype) - else: - return item - - return series.apply(convert_array_item) - - class RayOfflineStoreConfig(FeastConfigBaseModel): """ Configuration for the Ray Offline Store. + + For detailed configuration options and examples, see the documentation: + https://docs.feast.dev/reference/offline-stores/ray """ type: Literal[ @@ -359,7 +319,6 @@ class RayOfflineStoreConfig(FeastConfigBaseModel): ] = "ray" storage_path: Optional[str] = None ray_address: Optional[str] = None - use_ray_cluster: Optional[bool] = False # Optimization settings broadcast_join_threshold_mb: Optional[int] = 100 @@ -378,6 +337,7 @@ class RayOfflineStoreConfig(FeastConfigBaseModel): class RayResourceManager: """ Manages Ray cluster resources for optimal performance. + # See: https://docs.feast.dev/reference/offline-stores/ray#resource-management-and-testing """ def __init__(self, config: Optional[RayOfflineStoreConfig] = None) -> None: @@ -1265,7 +1225,7 @@ def _ensure_ray_initialized(config: Optional[RepoConfig] = None) -> None: if config and hasattr(config, "offline_store"): if isinstance(ray_config, RayOfflineStoreConfig): - if ray_config.use_ray_cluster and ray_config.ray_address: + if ray_config.ray_address: ray_init_kwargs["address"] = ray_config.ray_address else: ray_init_kwargs.update( diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py index 32c0d6dbabd..6e1fa66b102 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_repo_configuration.py @@ -30,7 +30,6 @@ def __init__(self, project_name: str, *args, **kwargs): type="ray", storage_path="/tmp/ray-storage", ray_address=None, - use_ray_cluster=False, broadcast_join_threshold_mb=25, max_parallelism_multiplier=1, target_partition_size_mb=16, diff --git a/sdk/python/feast/infra/ray_shared_utils.py b/sdk/python/feast/infra/ray_shared_utils.py index df8dfeb9fdb..6fa873ab6ae 100644 --- a/sdk/python/feast/infra/ray_shared_utils.py +++ b/sdk/python/feast/infra/ray_shared_utils.py @@ -141,7 +141,6 @@ def join_batch_with_features(batch: pd.DataFrame) -> pd.DataFrame: col for col in feature_cols if col in features.columns ] features_filtered = features[available_feature_cols].copy() - from .ray_shared_utils import normalize_timestamp_columns batch = normalize_timestamp_columns(batch, timestamp_field, inplace=True) features_filtered = normalize_timestamp_columns( diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index 6781c9a4301..c7f9096d9aa 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -1119,3 +1119,50 @@ def cb_columnar_type_to_feast_value_type(type_str: str) -> ValueType: if value == ValueType.UNKNOWN: print("unknown type:", type_str) return value + + +def convert_scalar_column( + series: pd.Series, value_type: ValueType, target_pandas_type: str +) -> pd.Series: + """Convert a scalar feature column to the appropriate pandas type.""" + if value_type == ValueType.INT32: + return pd.to_numeric(series, errors="coerce").astype("Int32") + elif value_type == ValueType.INT64: + return pd.to_numeric(series, errors="coerce").astype("Int64") + elif value_type in [ValueType.FLOAT, ValueType.DOUBLE]: + return pd.to_numeric(series, errors="coerce").astype("float64") + elif value_type == ValueType.BOOL: + return series.astype("boolean") + elif value_type == ValueType.STRING: + return series.astype("string") + elif value_type == ValueType.UNIX_TIMESTAMP: + return pd.to_datetime(series, unit="s", errors="coerce") + else: + return series.astype(target_pandas_type) + + +def convert_array_column(series: pd.Series, value_type: ValueType) -> pd.Series: + """Convert an array feature column to the appropriate type with proper empty array handling.""" + base_type_map = { + ValueType.INT32_LIST: np.int32, + ValueType.INT64_LIST: np.int64, + ValueType.FLOAT_LIST: np.float32, + ValueType.DOUBLE_LIST: np.float64, + ValueType.BOOL_LIST: np.bool_, + ValueType.STRING_LIST: object, + ValueType.BYTES_LIST: object, + ValueType.UNIX_TIMESTAMP_LIST: "datetime64[s]", + } + + target_dtype = base_type_map.get(value_type, object) + + def convert_array_item(item) -> Union[np.ndarray, Any]: + if item is None or (isinstance(item, list) and len(item) == 0): + if target_dtype == object: + return np.empty(0, dtype=object) + else: + return np.empty(0, dtype=target_dtype) # type: ignore + else: + return item + + return series.apply(convert_array_item) diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py index 6b74859022f..37d0d020ccd 100644 --- a/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py +++ b/sdk/python/tests/integration/compute_engines/ray_compute/repo_configuration.py @@ -19,7 +19,6 @@ def get_ray_compute_engine_test_config() -> IntegrationTestRepoConfig: offline_store_creator=RayDataSourceCreator, batch_engine={ "type": "ray.engine", - "use_ray_cluster": False, "max_workers": 1, "enable_optimization": True, }, diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py index 73cc6d19cd1..e7060b4a756 100644 --- a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py @@ -154,7 +154,6 @@ def test_ray_compute_engine_config(): """Test Ray compute engine configuration.""" config = RayComputeEngineConfig( type="ray.engine", - use_ray_cluster=True, ray_address="ray://localhost:10001", broadcast_join_threshold_mb=200, enable_distributed_joins=True, @@ -167,7 +166,6 @@ def test_ray_compute_engine_config(): ) assert config.type == "ray.engine" - assert config.use_ray_cluster is True assert config.ray_address == "ray://localhost:10001" assert config.broadcast_join_threshold_mb == 200 assert config.window_size_timedelta == timedelta(hours=2) diff --git a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py index c6cfc13280d..e8c40d43099 100644 --- a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py @@ -61,7 +61,6 @@ def ray_config(): """Create Ray compute engine configuration for testing.""" return RayComputeEngineConfig( type="ray.engine", - use_ray_cluster=False, max_workers=2, enable_optimization=True, broadcast_join_threshold_mb=50, @@ -297,7 +296,6 @@ def test_ray_config_validation(): # Test valid configuration config = RayComputeEngineConfig( type="ray.engine", - use_ray_cluster=False, max_workers=4, enable_optimization=True, broadcast_join_threshold_mb=100,