diff --git a/.flake8 b/.flake8
index 8fb7c1063..6e1c59faf 100644
--- a/.flake8
+++ b/.flake8
@@ -8,8 +8,9 @@ max-line-length = 120
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
# E731 allow usage of assigning lambda expressions
# N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
+# E704 ignored to allow black's formatting of Protocol stub methods (def method(self) -> None: ...)
ignore =
- E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
+ E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml
new file mode 100644
index 000000000..cfe8204b7
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature-request.yml
@@ -0,0 +1,27 @@
+name: ✨ Feature Request
+description: Suggest a new feature or enhancement for this project
+
+body:
+- type: markdown
+ attributes:
+ value: >
+ #### Before submitting a feature request, please search through [existing issues](https://github.com/meta-pytorch/forge/issues?q=is%3Aissue+sort%3Acreated-desc+) to see if something similar has already been proposed.
+- type: textarea
+ attributes:
+ label: Context/Motivation
+ description: |
+ Describe the problem you're trying to solve or the use case for this feature. Include any relevant links and context.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Pseudo-code + acceptance criteria [Optional]
+ description: |
+ Provide a rough sketch of what the API or implementation might look like. This helps us understand your vision for how the feature would work.
+ Also, if possible, include what would need to be true for this feature to be considered complete.
+ validations:
+ required: false
+- type: markdown
+ attributes:
+ value: >
+ Thanks for contributing 🎉!
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 000000000..056daa140
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,3 @@
+## Description
+
+## Test plan
diff --git a/.github/packaging/pre_build_cpu.sh b/.github/packaging/pre_build_cpu.sh
index 520bdedb1..60913449d 100644
--- a/.github/packaging/pre_build_cpu.sh
+++ b/.github/packaging/pre_build_cpu.sh
@@ -4,7 +4,13 @@ set -euxo pipefail
# Builds vLLM
# This script builds vLLM and places its wheel into dist/.
-VLLM_BRANCH="v0.10.0"
+# Script runs relative to forge root
+CURRENT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+echo "current dir is $CURRENT_DIR"
+VERSIONS_FILE="$CURRENT_DIR/../../assets/versions.sh"
+echo "versions file is $VERSIONS_FILE"
+source "$VERSIONS_FILE"
+
BUILD_DIR="$HOME/forge-build"
# Push other files to the dist folder
@@ -18,7 +24,7 @@ echo "wheel dir is $WHL_DIR"
build_vllm() {
cd "$BUILD_DIR"
- git clone https://github.com/vllm-project/vllm.git --branch $VLLM_BRANCH
+ git clone https://github.com/vllm-project/vllm.git --branch $VLLM_VERSION
cd "$BUILD_DIR/vllm"
python use_existing_torch.py
diff --git a/.github/packaging/pre_build_gpu.sh b/.github/packaging/pre_build_gpu.sh
deleted file mode 100644
index d81f52782..000000000
--- a/.github/packaging/pre_build_gpu.sh
+++ /dev/null
@@ -1,71 +0,0 @@
-#!/bin/bash
-set -euxo pipefail
-
-# Builds Monarch
-# This script builds Monarch and places its wheel into dist/.
-
-MONARCH_COMMIT="265034a29ec3fb35919f4a9c23c65f2f4237190d"
-BUILD_DIR="$HOME/forge-build"
-
-# Push other files to the dist folder
-WHL_DIR="${GITHUB_WORKSPACE}/wheels/dist"
-
-mkdir -p $BUILD_DIR
-mkdir -p $WHL_DIR
-echo "build dir is $BUILD_DIR"
-echo "wheel dir is $WHL_DIR"
-
-build_monarch() {
- # Get Rust build related pieces
- if ! command -v rustup &> /dev/null; then
- echo "getting rustup"
- curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
- export PATH="$HOME/.cargo/bin:$PATH"
- echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- fi
-
- rustup toolchain install nightly
- rustup default nightly
-
- if command -v dnf &>/dev/null; then
- dnf install -y clang-devel \
- libibverbs rdma-core libmlx5 libibverbs-devel rdma-core-devel fmt-devel \
- libunwind-devel
- elif command -v apt-get &>/dev/null; then
- apt-get update
- apt-get install -y clang libunwind-dev \
- libibverbs-dev librdmacm-dev libfmt-dev
- fi
-
- cd "$BUILD_DIR"
- git clone https://github.com/meta-pytorch/monarch.git
- cd "$BUILD_DIR/monarch"
- git checkout $MONARCH_COMMIT
-
- pip install -r build-requirements.txt
- export USE_TENSOR_ENGINE=1
- export RUST_BACKTRACE=1
- export CARGO_TERM_VERBOSE=true
- export CARGO_TERM_COLOR=always
- pip wheel --no-build-isolation --no-deps . -w "$WHL_DIR"
-}
-
-append_date() {
- cd ${GITHUB_WORKSPACE}/${REPOSITORY}
- # Appends the current date and time to the Forge wheel
- version_file="assets/version.txt"
- init_file="src/forge/__init__.py"
- if [[ -n "$BUILD_VERSION" ]]; then
- # Update the version in version.txt
- echo "$BUILD_VERSION" > "$version_file"
- # Create a variable named __version__ at the end of __init__.py
- echo "__version__ = \"$BUILD_VERSION\"" >> "$init_file"
- else
- echo "Error: BUILD_VERSION environment variable is not set or empty."
- exit 1
- fi
-}
-
-
-build_monarch
-append_date
\ No newline at end of file
diff --git a/.github/packaging/vllm_reqs_12_8.txt b/.github/packaging/vllm_reqs_12_8.txt
new file mode 100644
index 000000000..d1ba5e385
--- /dev/null
+++ b/.github/packaging/vllm_reqs_12_8.txt
@@ -0,0 +1,139 @@
+# This file was generated by running ./scripts/generate_vllm_reqs.sh
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.1
+aiosignal==1.4.0
+annotated-types==0.7.0
+anyio==4.11.0
+astor==0.8.1
+async-timeout==5.0.1
+attrs==25.4.0
+blake3==1.0.8
+cachetools==6.2.1
+cbor2==5.7.0
+certifi==2025.10.5
+cffi==2.0.0
+charset-normalizer==3.4.4
+click==8.2.1
+cloudpickle==3.1.1
+cmake==4.1.0
+compressed-tensors==0.10.2
+cupy-cuda12x==13.6.0
+depyf==0.19.0
+dill==0.4.0
+diskcache==5.6.3
+distro==1.9.0
+dnspython==2.8.0
+einops==0.8.1
+email-validator==2.3.0
+exceptiongroup==1.3.0
+fastapi==0.119.1
+fastapi-cli==0.0.14
+fastapi-cloud-cli==0.3.1
+fastrlock==0.8.3
+filelock==3.19.1
+frozenlist==1.8.0
+fsspec==2025.9.0
+gguf==0.17.1
+h11==0.16.0
+hf-xet==1.1.10
+httpcore==1.0.9
+httptools==0.7.1
+httpx==0.28.1
+huggingface-hub==0.35.3
+idna==3.11
+interegular==0.3.3
+Jinja2==3.1.6
+jiter==0.11.1
+jsonschema==4.25.1
+jsonschema-specifications==2025.9.1
+lark==1.2.2
+llguidance==0.7.30
+llvmlite==0.44.0
+lm-format-enforcer==0.10.12
+markdown-it-py==4.0.0
+MarkupSafe==2.1.5
+mdurl==0.1.2
+mistral_common==1.8.5
+mpmath==1.3.0
+msgpack==1.1.2
+msgspec==0.19.0
+multidict==6.7.0
+networkx==3.3
+ninja==1.13.0
+numba==0.61.2
+numpy==2.2.6
+nvidia-cublas-cu12==12.8.4.1
+nvidia-cuda-cupti-cu12==12.8.90
+nvidia-cuda-nvrtc-cu12==12.8.93
+nvidia-cuda-runtime-cu12==12.8.90
+nvidia-cudnn-cu12==9.10.2.21
+nvidia-cufft-cu12==11.3.3.83
+nvidia-cufile-cu12==1.13.1.3
+nvidia-curand-cu12==10.3.9.90
+nvidia-cusolver-cu12==11.7.3.90
+nvidia-cusparse-cu12==12.5.8.93
+nvidia-cusparselt-cu12==0.7.1
+nvidia-nccl-cu12==2.27.5
+nvidia-nvjitlink-cu12==12.8.93
+nvidia-nvshmem-cu12==3.3.20
+nvidia-nvtx-cu12==12.8.90
+openai==1.90.0
+opencv-python-headless==4.12.0.88
+outlines_core==0.2.10
+packaging==25.0
+partial-json-parser==0.2.1.1.post6
+pillow==12.0.0
+prometheus-fastapi-instrumentator==7.1.0
+prometheus_client==0.23.1
+propcache==0.4.1
+protobuf==6.33.0
+psutil==7.1.1
+py-cpuinfo==9.0.0
+pybase64==1.4.2
+pycountry==24.6.1
+pycparser==2.23
+pydantic==2.12.3
+pydantic-extra-types==2.10.6
+pydantic_core==2.41.4
+Pygments==2.19.2
+python-dotenv==1.1.1
+python-json-logger==4.0.0
+python-multipart==0.0.20
+PyYAML==6.0.3
+pyzmq==27.1.0
+ray==2.50.1
+referencing==0.37.0
+regex==2025.10.23
+requests==2.32.5
+rich==14.2.0
+rich-toolkit==0.15.1
+rignore==0.7.1
+rpds-py==0.27.1
+safetensors==0.6.2
+scipy==1.15.3
+sentencepiece==0.2.1
+sentry-sdk==2.42.1
+setuptools-scm==9.2.2
+shellingham==1.5.4
+sniffio==1.3.1
+soundfile==0.13.1
+soxr==1.0.0
+starlette==0.48.0
+sympy==1.14.0
+tiktoken==0.12.0
+tokenizers==0.22.1
+tomli==2.3.0
+torch==2.9.0+cu128
+tqdm==4.67.1
+transformers==4.57.1
+triton==3.5.0
+typer==0.20.0
+typing-inspection==0.4.2
+typing_extensions==4.15.0
+urllib3==2.5.0
+uvicorn==0.38.0
+uvloop==0.22.1
+watchfiles==1.1.1
+websockets==15.0.1
+xgrammar==0.1.21
+yarl==1.22.0
diff --git a/.github/packaging/vllm_reqs_12_9.txt b/.github/packaging/vllm_reqs_12_9.txt
new file mode 100644
index 000000000..aad2b28bd
--- /dev/null
+++ b/.github/packaging/vllm_reqs_12_9.txt
@@ -0,0 +1,147 @@
+# These requirements were generated by running steps 1-3 of scripts/build_wheels.sh
+# then running pip freeze and manually removing the vllm dependency.
+# The intention of this file is to use these known requirements for a fixed
+# vLLM build to supplement a vLLM install from download.pytorch.org without
+# resorting to --extra-index-url https://download.pytorch.org/whl/nightly to find
+# vLLM dependencies (as this results in a ResolutionTooDeep error from pip).
+# See the file .github/workflows/gpu_test.yaml for an E2E forge installation using this approach.
+# TODO: this should be done way less hackily
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.0
+aiosignal==1.4.0
+annotated-types==0.7.0
+anyio==4.11.0
+astor==0.8.1
+async-timeout==5.0.1
+attrs==25.4.0
+blake3==1.0.7
+cachetools==6.2.0
+cbor2==5.7.0
+certifi==2025.10.5
+cffi==2.0.0
+charset-normalizer==3.4.3
+click==8.3.0
+cloudpickle==3.1.1
+cmake==4.1.0
+compressed-tensors==0.10.2
+cupy-cuda12x==13.6.0
+depyf==0.19.0
+dill==0.4.0
+diskcache==5.6.3
+distro==1.9.0
+dnspython==2.8.0
+einops==0.8.1
+email-validator==2.3.0
+exceptiongroup==1.3.0
+fastapi==0.118.3
+fastapi-cli==0.0.13
+fastapi-cloud-cli==0.3.1
+fastrlock==0.8.3
+filelock==3.19.1
+frozenlist==1.8.0
+fsspec==2025.9.0
+gguf==0.17.1
+h11==0.16.0
+hf-xet==1.1.10
+httpcore==1.0.9
+httptools==0.7.1
+httpx==0.28.1
+huggingface-hub==0.35.3
+idna==3.10
+interegular==0.3.3
+Jinja2==3.1.6
+jiter==0.11.0
+jsonschema==4.25.1
+jsonschema-specifications==2025.9.1
+lark==1.2.2
+llguidance==0.7.30
+llvmlite==0.44.0
+lm-format-enforcer==0.10.12
+markdown-it-py==4.0.0
+MarkupSafe==3.0.2
+mdurl==0.1.2
+mistral_common==1.8.5
+mpmath==1.3.0
+msgpack==1.1.2
+msgspec==0.19.0
+multidict==6.7.0
+networkx==3.4.2
+ninja==1.13.0
+numba==0.61.2
+numpy==2.2.6
+nvidia-cublas-cu12==12.9.1.4
+nvidia-cuda-cupti-cu12==12.9.79
+nvidia-cuda-nvrtc-cu12==12.9.86
+nvidia-cuda-runtime-cu12==12.9.79
+nvidia-cudnn-cu12==9.10.2.21
+nvidia-cufft-cu12==11.4.1.4
+nvidia-cufile-cu12==1.14.1.1
+nvidia-curand-cu12==10.3.10.19
+nvidia-cusolver-cu12==11.7.5.82
+nvidia-cusparse-cu12==12.5.10.65
+nvidia-cusparselt-cu12==0.7.1
+nvidia-nccl-cu12==2.27.5
+nvidia-nvjitlink-cu12==12.9.86
+nvidia-nvshmem-cu12==3.3.20
+nvidia-nvtx-cu12==12.9.79
+openai==1.90.0
+opencv-python-headless==4.12.0.88
+outlines_core==0.2.10
+packaging==25.0
+partial-json-parser==0.2.1.1.post6
+pillow==11.3.0
+prometheus-fastapi-instrumentator==7.1.0
+prometheus_client==0.23.1
+propcache==0.4.1
+protobuf==6.32.1
+psutil==7.1.0
+py-cpuinfo==9.0.0
+pybase64==1.4.2
+pycountry==24.6.1
+pycparser==2.23
+pydantic==2.12.0
+pydantic-extra-types==2.10.6
+pydantic_core==2.41.1
+Pygments==2.19.2
+python-dotenv==1.1.1
+python-json-logger==4.0.0
+python-multipart==0.0.20
+pytorch-triton==3.4.0+gitf7888497
+PyYAML==6.0.3
+pyzmq==27.1.0
+ray==2.49.2
+referencing==0.36.2
+regex==2025.9.18
+requests==2.32.5
+rich==14.2.0
+rich-toolkit==0.15.1
+rignore==0.7.0
+rpds-py==0.27.1
+safetensors==0.6.2
+scipy==1.15.3
+sentencepiece==0.2.1
+sentry-sdk==2.41.0
+setuptools-scm==9.2.0
+shellingham==1.5.4
+sniffio==1.3.1
+soundfile==0.13.1
+soxr==1.0.0
+starlette==0.48.0
+sympy==1.14.0
+tiktoken==0.12.0
+tokenizers==0.22.1
+tomli==2.3.0
+torch==2.9.0.dev20250905+cu129
+tqdm==4.67.1
+transformers==4.57.0
+triton==3.4.0
+typer==0.19.2
+typing-inspection==0.4.2
+typing_extensions==4.15.0
+urllib3==2.5.0
+uvicorn==0.37.0
+uvloop==0.21.0
+watchfiles==1.1.0
+websockets==15.0.1
+xgrammar==0.1.21
+yarl==1.22.0
diff --git a/.github/workflows/build_vllm.yaml b/.github/workflows/build_vllm.yaml
index 5e2f0db61..442a3739d 100644
--- a/.github/workflows/build_vllm.yaml
+++ b/.github/workflows/build_vllm.yaml
@@ -12,39 +12,71 @@ permissions:
jobs:
build:
- name: forge-cu129-nightly
- uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@28a1b658404f17c8eabde5f7fe25ae3ac826fae6
+ name: forge-cu128-nightly
+ uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@vllm-push
strategy:
fail-fast: false
with:
repository: meta-pytorch/forge
ref: ""
test-infra-repository: pytorch/test-infra
- test-infra-ref: 28a1b658404f17c8eabde5f7fe25ae3ac826fae6
+ test-infra-ref: vllm-push
run-smoke-test: false
- wheel-upload-path: whl/preview/forge
+ wheel-nightly-policy: gha_workflow_preview_build_wheels
+ wheel-upload-path: whl/preview/forge/
package-name: forge
+ channel: test # Hack here to make sure stable pytorch is used
build-matrix: |
{
"include": [
{
"python_version": "3.10",
"gpu_arch_type": "cpu",
- "gpu_arch_version": "12.9",
- "desired_cuda": "cu129",
- "container_image": "pytorch/manylinux2_28-builder:cuda12.9",
+ "gpu_arch_version": "12.8",
+ "desired_cuda": "cu128",
+ "container_image": "pytorch/manylinux2_28-builder:cuda12.8",
"package_type": "manywheel",
- "build_name": "manywheel-py3_10-cuda12_9",
+ "build_name": "manywheel-py3_10-cuda12_8",
"validation_runner": "linux.12xlarge.memory",
- "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129",
- "channel": "nightly",
+ "installation": "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128",
+ "channel": "test",
"upload_to_base_bucket": "no",
- "stable_version": "2.8.0",
+ "stable_version": "2.9.0",
"use_split_build": false
- }
+ },
+ {
+ "python_version": "3.11",
+ "gpu_arch_type": "cpu",
+ "gpu_arch_version": "12.8",
+ "desired_cuda": "cu128",
+ "container_image": "pytorch/manylinux2_28-builder:cuda12.8",
+ "package_type": "manywheel",
+ "build_name": "manywheel-py3_11-cuda12_8",
+ "validation_runner": "linux.12xlarge.memory",
+ "installation": "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128",
+ "channel": "test",
+ "upload_to_base_bucket": "no",
+ "stable_version": "2.9.0",
+ "use_split_build": false
+ },
+ {
+ "python_version": "3.12",
+ "gpu_arch_type": "cpu",
+ "gpu_arch_version": "12.8",
+ "desired_cuda": "cu128",
+ "container_image": "pytorch/manylinux2_28-builder:cuda12.8",
+ "package_type": "manywheel",
+ "build_name": "manywheel-py3_12-cuda12_8",
+ "validation_runner": "linux.12xlarge.memory",
+ "installation": "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128",
+ "channel": "test",
+ "upload_to_base_bucket": "no",
+ "stable_version": "2.9.0",
+ "use_split_build": false
+ },
]
}
pre-script: .github/packaging/pre_build_cpu.sh
post-script: .github/packaging/post_build_script.sh
trigger-event: ${{ github.event_name }}
- build-platform: 'python-build-package'
\ No newline at end of file
+ build-platform: 'python-build-package'
diff --git a/.github/workflows/build_wheels.yaml b/.github/workflows/build_wheels.yaml
deleted file mode 100644
index 05f5e3997..000000000
--- a/.github/workflows/build_wheels.yaml
+++ /dev/null
@@ -1,50 +0,0 @@
-name: Build nightly wheels and publish to PyTorch Index
-
-on:
- push:
- branches:
- - nightly
- workflow_dispatch:
-
-permissions:
- id-token: write
- contents: read
-
-jobs:
- build:
- name: forge-cu129-nightly
- uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@28a1b658404f17c8eabde5f7fe25ae3ac826fae6
- strategy:
- fail-fast: false
- with:
- repository: meta-pytorch/forge
- ref: ""
- test-infra-repository: pytorch/test-infra
- test-infra-ref: 28a1b658404f17c8eabde5f7fe25ae3ac826fae6
- run-smoke-test: false
- wheel-upload-path: whl/preview/forge
- package-name: forge
- build-matrix: |
- {
- "include": [
- {
- "python_version": "3.10",
- "gpu_arch_type": "cuda",
- "gpu_arch_version": "12.9",
- "desired_cuda": "cu129",
- "container_image": "pytorch/manylinux2_28-builder:cuda12.9",
- "package_type": "manywheel",
- "build_name": "manywheel-py3_10-cuda12_9",
- "validation_runner": "linux.4xlarge.nvidia.gpu",
- "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129",
- "channel": "nightly",
- "upload_to_base_bucket": "no",
- "stable_version": "2.8.0",
- "use_split_build": false
- }
- ]
- }
- pre-script: .github/packaging/pre_build_gpu.sh
- post-script: .github/packaging/post_build_script.sh
- trigger-event: ${{ github.event_name }}
- build-platform: 'python-build-package'
\ No newline at end of file
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index a50ed3200..2ea947224 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -9,8 +9,9 @@ on:
jobs:
build-docs:
+ if: github.repository_owner == 'meta-pytorch'
name: Build Documentation
- runs-on: ubuntu-latest
+ runs-on: linux.g5.4xlarge.nvidia.gpu
timeout-minutes: 30
steps:
- name: Checkout
@@ -24,88 +25,97 @@ jobs:
miniconda-version: "latest"
activate-environment: test
python-version: '3.10'
- auto-activate-base: false
- - name: Verify conda environment
- shell: bash -l {0}
- run: |
- conda info
- which python
- which conda
+ auto-activate: false
- name: Update pip
shell: bash -l {0}
run: python -m pip install --upgrade pip
- - name: Install pytorch
- shell: bash -l {0}
- run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
- - name: Install monarch
- shell: bash -l {0}
- run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
- name: Install torchforge
shell: bash -l {0}
- env:
- GH_TOKEN: ${{ github.token }}
- run: ./scripts/install.sh
- - name: Install docs dependencies
- shell: bash -l {0}
- run: python -m pip install -r docs/requirements.txt
+ run: pip install uv && uv pip install . && uv pip install .[docs]
- name: Build docs
shell: bash -l {0}
working-directory: docs
- run: make html --keep-going SPHINXOPTS='-W'
+ run: make html
- name: Upload docs artifact
uses: actions/upload-artifact@v4
with:
name: docs
path: docs/build/html/
- # doc-preview:
- # runs-on: [ubuntu-latest]
- # needs: build-docs
- # if: ${{ github.event_name == 'pull_request' }}
- # steps:
- # - name: Checkout
- # uses: actions/checkout@v4
- # - name: Download artifact
- # uses: actions/download-artifact@v4
- # with:
- # name: docs
- # path: docs
- # - name: Add noindex to preview docs
- # run: |
- # echo "Adding noindex meta tag to prevent search engine indexing of preview docs"
- # find docs -name "*.html" -print0 | xargs -0 sed -i 's/
/\n /'
- # - name: Upload docs preview
- # uses: seemethere/upload-artifact-s3@v5
- # if: ${{ github.event_name == 'pull_request' }}
- # with:
- # retention-days: 14
- # s3-bucket: doc-previews
- # if-no-files-found: error
- # path: docs
- # s3-prefix: meta-pytorch/forge/${{ github.event.pull_request.number }}
+ doc-preview:
+ runs-on: linux.large
+ needs: build-docs
+ if: ${{ github.event_name == 'pull_request' }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Download artifact
+ uses: actions/download-artifact@v4
+ with:
+ name: docs
+ path: docs
+ - name: Add noindex to preview docs
+ run: |
+ echo "Adding noindex meta tag to prevent search engine indexing of preview docs"
+ find docs -name "*.html" -print0 | xargs -0 sed -i 's//\n /'
+ - name: Upload docs preview
+ uses: seemethere/upload-artifact-s3@v5
+ if: ${{ github.event_name == 'pull_request' }}
+ with:
+ retention-days: 14
+ s3-bucket: doc-previews
+ if-no-files-found: error
+ path: docs
+ s3-prefix: meta-pytorch/torchforge/${{ github.event.pull_request.number }}
+
+ upload:
+ runs-on: ubuntu-latest
+ permissions:
+ # Grant write permission here so that the doc can be pushed to gh-pages branch
+ contents: write
+ needs: build-docs
+ if: github.repository == 'meta-pytorch/torchforge' && github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v') || github.event_name == 'workflow_dispatch')
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ ref: gh-pages
+ persist-credentials: true
+ - name: Download artifact
+ uses: actions/download-artifact@v4
+ with:
+ name: docs
+ path: docs
+ #- name: Add no-index tag
+ # run: |
+ # REF_NAME=$(echo "${{ github.ref }}")
+ # echo "Ref name: ${REF_NAME}"
+ # if [[ "${{ github.ref }}" == 'refs/heads/main' ]]; then
+ # find docs -name "*.html" -print0 | xargs -0 sed -i '//a \ \ ';
+ # fi
+ - name: Move and commit changes
+ run: |
+ set -euo pipefail
+ # Get github.ref for the output doc folder. By default "main"
+ # If matches a tag like refs/tags/v1.12.0-rc3 or
+ # refs/tags/v1.12.0 convert to 1.12
+ GITHUB_REF=${{ github.ref }}
- deploy-docs:
- needs: build-docs
- if: github.ref == 'refs/heads/main'
- permissions:
- pages: write
- id-token: write
- environment:
- name: github-pages
- url: ${{ steps.deployment.outputs.page_url }}
- runs-on: ubuntu-latest
- steps:
- - name: Download build artifact
- uses: actions/download-artifact@v4
- with:
- name: docs
- path: .
+ # Convert refs/tags/v1.12.0rc3 into 1.12.
+ # Adopted from https://github.com/pytorch/pytorch/blob/main/.github/workflows/_docs.yml#L150C11-L155C13
+ if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\.[0-9]+)\.* ]]; then
+ TARGET_FOLDER="${BASH_REMATCH[1]}"
+ else
+ TARGET_FOLDER="main"
+ fi
+ echo "Target Folder: ${TARGET_FOLDER}"
- - name: Upload Pages artifact
- uses: actions/upload-pages-artifact@v3
- with:
- path: .
+ mkdir -p "${TARGET_FOLDER}"
+ rm -rf "${TARGET_FOLDER}"/*
+ mv docs/* "${TARGET_FOLDER}"
- - name: Deploy to GitHub Pages
- id: deployment
- uses: actions/deploy-pages@v4
+ git config user.name 'pytorchbot'
+ git config user.email 'soumith+bot@pytorch.org'
+ git add "${TARGET_FOLDER}" || true
+ git commit -m "auto-generating sphinx docs" || true
+ git push -f
diff --git a/.github/workflows/unit_test.yaml b/.github/workflows/gpu_test.yaml
similarity index 51%
rename from .github/workflows/unit_test.yaml
rename to .github/workflows/gpu_test.yaml
index 9a839f32d..71455c122 100644
--- a/.github/workflows/unit_test.yaml
+++ b/.github/workflows/gpu_test.yaml
@@ -1,13 +1,27 @@
-name: Unit Test
+name: Unit Tests (GPU)
on:
+ push:
+ branches: [ main ]
pull_request:
+ workflow_dispatch:
+concurrency:
+ group: gpu-test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
+ cancel-in-progress: true
+
+permissions:
+ id-token: write
+ contents: read
+
+defaults:
+ run:
+ shell: bash -l -eo pipefail {0}
jobs:
- unit_tests:
- runs-on: ubuntu-latest
- timeout-minutes: 15
+ gpu_test:
+ if: github.repository_owner == 'meta-pytorch'
+ runs-on: linux.g5.12xlarge.nvidia.gpu
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12']
@@ -23,18 +37,8 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Update pip
run: python -m pip install --upgrade pip
- - name: Install pytorch
- run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
- - name: Install monarch
- run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
- - name: Install torchstore
- run: pip install assets/wheels/torchstore-0.1.0-py3-none-any.whl
- - name: Install torchtitan
- run: |
- pip install assets/wheels/torchtitan-0.1.0-py3-none-any.whl
- pip install tyro
- - name: Install dependencies
- run: python -m pip install --no-build-isolation -e ".[dev]"
+ - name: Install torchforge
+ run: pip install uv && uv pip install . && uv pip install .[dev]
- name: Run unit tests with coverage
# TODO add all tests
run: pytest tests/unit_tests --cov=. --cov-report=xml --durations=20 -vv
diff --git a/.gitignore b/.gitignore
index 160cb21a4..c952405d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -153,6 +153,7 @@ docs/source/generated_examples/
docs/source/gen_modules/
docs/source/generated/
docs/source/sg_execution_times.rst
+docs/source/tutorials/*
# pytorch-sphinx-theme gets installed here
docs/src
@@ -199,3 +200,7 @@ assets/wheels/vllm*.whl
# DCP artifacts
forge_dcp_tmp/
demo_top_down.md
+
+
+# enroot / sqsh
+*.sqsh
diff --git a/.meta/mast/README.md b/.meta/mast/README.md
new file mode 100644
index 000000000..4f4ccb2b3
--- /dev/null
+++ b/.meta/mast/README.md
@@ -0,0 +1,125 @@
+# Forge MAST Environment Setup
+
+A simple setup script to automatically configure your environment for running Forge with MAST jobs.
+This only applies to Meta internal users.
+
+## Quick Start
+
+### 1. Run the Setup Script
+
+The `env_setup.sh` script will automatically:
+- ✅ Activate and configure the required conda environment
+- ✅ Clone/update the Forge repository
+- ✅ Install Forge package dependencies
+- ✅ Mount the required oilfs workspace to `/mnt/wsfuse`
+- ✅ Configure your environment for MAST job submission
+
+```bash
+# Make the script executable
+chmod +x .meta/mast/env_setup.sh
+
+# Run the setup
+source .meta/mast/env_setup.sh
+
+```
+
+### 2. Submit MAST job
+
+Use the launch script to submit a MAST job:
+
+```bash
+# Make the launch script executable (first time only)
+chmod +x .meta/mast/launch.sh
+
+# Launch a job with your desired config
+./.meta/mast/launch.sh .meta/mast/qwen3_1_7b_mast.yaml
+```
+
+The launch script will automatically:
+- Navigate to the forge root directory
+- Reinstall the forge package with your latest changes
+- Set the correct PYTHONPATH
+- Launch the MAST job with the specified config
+
+You can run it from anywhere, and it will figure out the correct paths.
+
+
+## How MAST Launcher Works
+
+The MAST launcher uses a two-stage architecture to run training jobs:
+
+### Stage 1: Detached Mode (Local Machine)
+
+When you run `./.meta/mast/launch.sh`, the `main.py` script starts in **detached mode**:
+
+1. The launcher creates a MAST job with all the worker roles (GPU hosts)
+2. It also creates a special **client role** - a CPU-only role that will run inside MAST
+3. The client role's entrypoint is set to `client_bootstrap.sh`
+4. All CLI arguments you pass are forwarded to the client role
+
+At this point, the job is submitted to MAST and your local script exits. Everything now runs in the cluster.
+
+### Stage 2: Remote Mode (Inside MAST)
+
+The `client_bootstrap.sh` script runs inside the MAST client role and:
+
+1. Calls `main.py` again, but now with `--mode=remote`
+2. In **remote mode**, the script:
+ - Mounts the OilFS workspace
+ - Initializes the provisioner to connect to worker roles
+ - Runs the actual training workload (e.g., GRPO)
+
+This architecture allows the entire training workflow to run inside MAST without requiring a persistent connection from your local machine.
+
+### Key Files
+
+- **`main.py`**: Entry point that handles both detached and remote modes
+- **`client_bootstrap.sh`**: Entrypoint for the client role in MAST
+- **`launcher.py`**: Creates the MAST job specification and handles role configuration
+
+
+## Managing HuggingFace Models in MAST
+
+### The Problem: No Internet Access
+
+MAST compute nodes cannot access the internet, which means they cannot download models directly from HuggingFace. To work around this, we store all HuggingFace models and cache data on OilFS at `/mnt/wsfuse/teamforge/hf`, which is accessible from MAST.
+
+### Solution: Two-Step Process
+
+You need to perform both steps below to ensure models work correctly in MAST:
+
+#### 1. Download Model Weights to OilFS
+
+First, download the model weights directly to the OilFS path. This should be done from a machine with internet access (like your devserver):
+
+```bash
+# Set HF_HOME to the OilFS path
+export HF_HOME=/mnt/wsfuse/teamforge/hf
+
+# Download the model (replace with your desired model)
+hf download Qwen/Qwen3-8B --local-dir /mnt/wsfuse/teamforge/hf/qwen3_8b
+```
+
+#### 2. Hydrate the HuggingFace Cache
+
+After downloading the weights, you need to hydrate the HuggingFace cache so that the transformers library can find the model metadata:
+
+```bash
+# Set HF_HOME to the OilFS path
+export HF_HOME=/mnt/wsfuse/teamforge/hf
+
+# Hydrate the cache for the model
+python .meta/mast/hydrate_cache.py --model-id Qwen/Qwen3-8B
+```
+
+This ensures that when MAST runs with `HF_HUB_OFFLINE=1`, the transformers library can locate all necessary files from the cache.
+
+### Directory Structure
+
+Both cache and model files are stored under:
+- **Cache**: `/mnt/wsfuse/teamforge/hf` (set via `HF_HOME`)
+- **Model weights**: `/mnt/wsfuse/teamforge/hf/`
+
+
+#### Weights & Biases
+If you are part of the torchforge team on WandB, then WandB will work out of the box; the link can be found in the MAST logs. If you are not part of the torchforge team on WandB, then you will need to set the "WANDB_API_KEY" environment variable to your WandB API key.
diff --git a/apps/mast/__init__.py b/.meta/mast/__init__.py
similarity index 100%
rename from apps/mast/__init__.py
rename to .meta/mast/__init__.py
diff --git a/.meta/mast/client_bootstrap.sh b/.meta/mast/client_bootstrap.sh
new file mode 100755
index 000000000..9b8a704c3
--- /dev/null
+++ b/.meta/mast/client_bootstrap.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Bootstrap script for the MAST client role
+# This script sets up the environment and launches the client training script
+
+set -eEx
+
+LIBCUDA="/usr/local/fbcode/platform010/lib/libcuda.so"
+if [ -f "$LIBCUDA" ]; then
+ export LIBCUDA_DIR="${LIBCUDA%/*}"
+ export TRITON_LIBCUDA_PATH="$LIBCUDA_DIR"
+ export LD_PRELOAD="$LIBCUDA:/usr/local/fbcode/platform010/lib/libnvidia-ml.so${PRELOAD_PATH:+:$PRELOAD_PATH}"
+fi
+
+# Also preload put path to torch libs as for monarch dev workflow we dont
+# install it into the env so we need to make sure the binaries can find
+# libtorch and friends on mast and the rpaths set during dev install will
+# be wrong on mast.
+export LD_LIBRARY_PATH="${CONDA_DIR}/lib:${CONDA_DIR}/lib/python3.10/site-packages/torch/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
+export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$TORCHX_RUN_PYTHONPATH"
+
+# shellcheck disable=SC1091
+if [ -n "$CONDA_PREFIX" ]; then
+ echo "A conda environment is already activated: $CONDA_DEFAULT_ENV"
+else
+ # Disable command printing to avoid log spew.
+ set +x
+ source "${CONDA_DIR}/bin/activate"
+ # Re-enable command printing after conda activation.
+ set -x
+fi
+
+if [ -z "$WORKSPACE_DIR" ] || [ ! -d "$WORKSPACE_DIR" ]; then
+ WORKSPACE_DIR="$CONDA_PREFIX"
+fi
+
+cd "$WORKSPACE_DIR/torchforge"
+
+# Execute the client training script with all passed arguments
+exec python -X faulthandler .meta/mast/main.py "$@"
diff --git a/apps/mast/env_setup.sh b/.meta/mast/env_setup.sh
similarity index 59%
rename from apps/mast/env_setup.sh
rename to .meta/mast/env_setup.sh
index 8d14371ac..47563ccd1 100755
--- a/apps/mast/env_setup.sh
+++ b/.meta/mast/env_setup.sh
@@ -6,8 +6,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-# setup_forge_env.sh - Setup conda environment and install forge with mounting
-set -e # Exit on any error
+# Set up conda environment and install forge with mounting
+
+# Configuration
+CONDA_ENV_NAME="forge:latest_conveyor_build"
# Colors for output
RED='\033[0;31m'
@@ -45,6 +47,7 @@ mount_workspace() {
log_info "Creating mount directory: $mount_dir"
sudo mkdir -p "$mount_dir" || {
log_error "Failed to create mount directory (may need sudo privileges)"
+ log_error "You could alternatively try to unmount with `sudo umount /mnt/wsfuse`"
return 1
}
fi
@@ -105,8 +108,6 @@ fi
# Define paths
FBSOURCE_PATH="/data/users/$USER/fbsource"
CONDA_SCRIPT_PATH="$FBSOURCE_PATH/genai/xlformers/dev/xl_conda.sh"
-FORGE_BASE_DIR="/data/users/$USER"
-FORGE_REPO_DIR="$FORGE_BASE_DIR/forge"
# Workspace URL for mounting
WORKSPACE_URL="ws://ws.ai.pci0ai/genai_fair_llm"
@@ -130,66 +131,97 @@ if [ ! -f "$CONDA_SCRIPT_PATH" ]; then
fi
log_info "Sourcing conda script: $CONDA_SCRIPT_PATH"
-source "$CONDA_SCRIPT_PATH" activate forge:e146614
+source "$CONDA_SCRIPT_PATH" activate "$CONDA_ENV_NAME"
if [ $? -ne 0 ]; then
- log_error "Failed to activate conda environment forge-e146614"
+ log_error "Failed to activate conda environment $CONDA_ENV_NAME"
exit 1
fi
log_info "Conda environment activated successfully"
-# Step 3: Create and navigate to forge base directory
-log_info "Step 3: Setting up forge directory..."
-if [ ! -d "$FORGE_BASE_DIR" ]; then
- log_info "Creating forge base directory: $FORGE_BASE_DIR"
- mkdir -p "$FORGE_BASE_DIR"
-fi
-cd "$FORGE_BASE_DIR"
-log_info "Changed to directory: $(pwd)"
+# Step 3: Install torchtitan
+log_info "Step 3: Installing torchtitan..."
+
+# Source versions.sh to get the pinned commit
+VERSIONS_FILE="assets/versions.sh"
+if [ -f "$VERSIONS_FILE" ]; then
+ log_info "Sourcing version information from: $VERSIONS_FILE"
+ source "$VERSIONS_FILE"
-# Step 4: Clone or update forge repository
-log_info "Step 4: Setting up forge git repository..."
-if [ -d "$FORGE_REPO_DIR" ]; then
- log_warn "Forge repository already exists at: $FORGE_REPO_DIR"
- cd "$FORGE_REPO_DIR"
+ if [ -n "$TORCHTITAN_COMMIT_MAST" ]; then
+ log_info "Installing torchtitan from commit: $TORCHTITAN_COMMIT_MAST"
+ pip uninstall -y torchtitan
+ pip install "git+https://github.com/pytorch/torchtitan.git@$TORCHTITAN_COMMIT_MAST"
- if [ -d ".git" ]; then
- log_info "Updating existing repository..."
- git fetch origin
if [ $? -eq 0 ]; then
- log_info "Repository updated successfully"
+ log_info "Torchtitan installed successfully"
else
- log_warn "Failed to fetch updates, continuing with existing code"
+ log_error "Failed to install torchtitan"
+ exit 1
fi
else
- log_error "Directory exists but is not a git repository"
- log_info "Removing directory and cloning fresh..."
- cd "$FORGE_BASE_DIR"
- rm -rf "$FORGE_REPO_DIR"
- git clone git@github.com:meta-pytorch/forge.git
- if [ $? -ne 0 ]; then
- log_error "Failed to clone forge repository"
+ log_error "TORCHTITAN_COMMIT_MAST not found in versions.sh"
+ exit 1
+ fi
+else
+ log_error "versions.sh not found at: $VERSIONS_FILE"
+ log_error "Cannot proceed without version information"
+ exit 1
+fi
+
+# Step 3.5: Apply monarch torch import hack
+log_info "Step 3.5: Applying monarch torch import hack..."
+
+MONARCH_INIT="$CONDA_PREFIX/lib/python3.10/site-packages/monarch/__init__.py"
+if [ -f "$MONARCH_INIT" ]; then
+ # Check if we already applied the hack
+ if grep -q "^import torch # Injected by forge setup" "$MONARCH_INIT"; then
+ log_info "Monarch torch import hack already applied, skipping"
+ else
+ log_info "Injecting 'import torch' into monarch/__init__.py"
+
+ # Create a backup
+ cp "$MONARCH_INIT" "$MONARCH_INIT.bak"
+
+ # Use sed to inject 'import torch' before the "# Import before monarch" comment
+ # We add it right after "from typing import TYPE_CHECKING" and before the comment
+ sed -i '/^from typing import TYPE_CHECKING$/a\
+\
+# Torch must be imported before monarch (injected by forge setup)\
+import torch # Injected by forge setup' "$MONARCH_INIT"
+
+ if [ $? -eq 0 ]; then
+ log_info "Successfully injected torch import into monarch/__init__.py"
+ else
+ log_error "Failed to inject torch import, restoring backup"
+ mv "$MONARCH_INIT.bak" "$MONARCH_INIT"
exit 1
fi
- cd "$FORGE_REPO_DIR"
fi
else
- log_info "Cloning forge repository..."
- git clone git@github.com:meta-pytorch/forge.git
- if [ $? -ne 0 ]; then
- log_error "Failed to clone forge repository"
- log_error "Please ensure:"
- log_error "1. You have SSH access to github.com"
- log_error "2. Your SSH key is added to GitHub"
- log_error "3. You have access to meta-pytorch/forge repository"
- exit 1
- fi
- cd "$FORGE_REPO_DIR"
+ log_warn "monarch/__init__.py not found at: $MONARCH_INIT"
+ log_warn "Skipping monarch torch import hack (monarch may not be installed yet)"
fi
-log_info "Current directory: $(pwd)"
+# Step 4: Check for existing build directory and warn user
+log_info "Step 4: Checking for existing build directory..."
+if [ -d "build" ]; then
+ log_warn "Detected existing build/ directory at: $(pwd)/build"
+ log_warn "This directory may contain artifacts from a previous pip installation"
+ log_warn "that could interfere with the current installation."
+ log_warn "If you encounter issues, manually remove it with: rm -rf build"
+ echo ""
+ read -p "$(echo -e ${YELLOW}Do you want to continue anyway? [y/N]:${NC} )" -n 1 -r
+ echo ""
+ if [[ ! $REPLY =~ ^[Yy]$ ]]; then
+ log_info "Installation cancelled by user"
+ log_info "You can manually remove the build/ directory with: rm -rf build"
+ exit 0
+ fi
+ log_warn "Continuing with existing build/ directory. Things might go wrong!"
+fi
# Step 5: Install forge package
log_info "Step 5: Installing forge package..."
@@ -230,9 +262,13 @@ pip list | grep -E "(forge|monarch)" || log_warn "No forge/monarch packages foun
log_info "Environment setup complete! You can now run your scripts."
log_info "Mounted workspace available at: /mnt/wsfuse"
-# Step 6: Ask user to deactivate and activate conda env conda environment
+log_info "Unsetting CUDA_HOME and overwriting the LD_LIBRARY_PATH"
+unset CUDA_HOME
+export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib
+
+# Step 6: Ask user to test
echo ""
log_info "Installation completed successfully!"
echo ""
-log_info "Re-activate the conda environment to make the changes take effect:"
-log_info "conda deactivate && conda activate forge-e146614"
+log_info "Test that this is working locally with:"
+log_info "python -m apps.grpo.main --config=apps/grpo/qwen3_1_7b.yaml"
diff --git a/.meta/mast/hydrate_cache.py b/.meta/mast/hydrate_cache.py
new file mode 100644
index 000000000..c99357eca
--- /dev/null
+++ b/.meta/mast/hydrate_cache.py
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""This is convenience script meant for hydrating the HuggingFace cache.
+
+This is meant for downloading the model weights and tokenizer to the cache, i.e. for
+OilFS.
+
+Example:
+
+python .meta/mast/hydrate_cache.py --model-id Qwen/Qwen3-32B
+
+"""
+
+import argparse
+import os
+import sys
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Hydrate HuggingFace cache for a specific model"
+ )
+ parser.add_argument(
+ "--model-id",
+ type=str,
+ required=True,
+ help="HuggingFace model ID (e.g., Qwen/Qwen3-8B)",
+ )
+ args = parser.parse_args()
+
+ # Ensure HF_HOME is set
+ hf_home = os.environ.get("HF_HOME")
+ if not hf_home:
+ print(
+ "ERROR: HF_HOME environment variable must be set. "
+ "You will likely want to run export HF_HOME=/mnt/wsfuse/teamforge/hf."
+ )
+ sys.exit(1)
+
+ print(f"Using HF_HOME: {hf_home}")
+ print(f"Downloading {args.model_id}...")
+
+ # This will pull tokenizer + config + all weight shards
+ tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(args.model_id, trust_remote_code=True)
+
+ print("Download complete. Cache hydrated.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/.meta/mast/launch.sh b/.meta/mast/launch.sh
new file mode 100755
index 000000000..aa28003fe
--- /dev/null
+++ b/.meta/mast/launch.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# launch.sh - Launch MAST jobs with Forge
+set -e # Exit on any error
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m' # No Color
+
+# Logging functions
+log_info() {
+ echo -e "${GREEN}[INFO]${NC} $1"
+}
+
+log_error() {
+ echo -e "${RED}[ERROR]${NC} $1"
+}
+
+# Check if config file is provided
+if [ $# -eq 0 ]; then
+ log_error "No config file provided"
+ echo "Usage: $0 "
+ echo "Example: $0 .meta/mast/qwen3_1_7b_mast.yaml"
+ exit 1
+fi
+
+CONFIG_FILE="$1"
+
+# Generate a unique job name based on the config file name
+BASENAME=$(basename "$CONFIG_FILE" .yaml)
+RANDOM_SUFFIX=$(cat /dev/urandom | tr -dc 'a-z0-9' | fold -w 6 | head -n 1)
+JOB_NAME="${BASENAME}-${RANDOM_SUFFIX}"
+log_info "Generated job name: $JOB_NAME"
+
+# Get the directory where this script is located
+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+# Navigate to forge root (two levels up from .meta/mast/)
+FORGE_ROOT="$( cd "$SCRIPT_DIR/../.." && pwd )"
+
+log_info "Forge root directory: $FORGE_ROOT"
+log_info "Config file: $CONFIG_FILE"
+
+# Check if config file exists
+if [ ! -f "$FORGE_ROOT/$CONFIG_FILE" ]; then
+ log_error "Config file not found: $FORGE_ROOT/$CONFIG_FILE"
+ exit 1
+fi
+
+# Navigate to forge root
+cd "$FORGE_ROOT"
+log_info "Changed to directory: $(pwd)"
+
+# Reinstall forge package
+log_info "Reinstalling forge package..."
+pip install --force-reinstall --no-deps .
+if [ $? -ne 0 ]; then
+ log_error "Failed to reinstall forge package"
+ exit 1
+fi
+
+log_info "Successfully reinstalled forge package"
+
+# Launch the job
+CHECKPOINT_FOLDER=/mnt/wsfuse/teamforge/forge_runs/$JOB_NAME
+log_info "Launching MAST job..."
+
+# Manually override the relevant checkpoint path(s)
+# This unfortunately cannot be done in the YAML itself since this should be
+# based on job name...
+PYTHONPATH=. python .meta/mast/main.py --job-name $JOB_NAME --config $CONFIG_FILE trainer.checkpoint.folder=${CHECKPOINT_FOLDER} trainer.dcp_path=${CHECKPOINT_FOLDER}
diff --git a/.meta/mast/main.py b/.meta/mast/main.py
new file mode 100644
index 000000000..d901d1ab5
--- /dev/null
+++ b/.meta/mast/main.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import asyncio
+import os
+import sys
+
+from apps.grpo.main import main as grpo_main
+from forge.controller.launcher import (
+ JOB_NAME_KEY,
+ LAUNCHER_KEY,
+ MastLauncher,
+ mount_mnt_directory,
+)
+from forge.controller.provisioner import init_provisioner
+
+from forge.types import (
+ Launcher,
+ LauncherConfig,
+ ProcessConfig,
+ ProvisionerConfig,
+ ServiceConfig,
+)
+from forge.util.config import parse
+from omegaconf import DictConfig
+
+DEFAULT_CHECKPOINT_FOLDER_KEY = "checkpoint_folder"
+DEFAULT_CHECKPOINT_FOLDER = "/mnt/wsfuse/teamforge/forge_runs/"
+
+
+def setup_wandb_api_key() -> None:
+ # add wandb API key to the environment
+ if "WANDB_API_KEY" in os.environ:
+ print("[wandb] WANDB_API_KEY already set in environment.")
+ return
+ secret_name = "TORCHFORGE_WANDB_API_KEY"
+ print(f"[wandb] Attempting to retrieve API key from keychain {secret_name=}")
+ try:
+ import base64
+
+ from cif import client
+
+ response = client.request(
+ "keychain.service",
+ "getSecretV2",
+ {
+ "request": {
+ "name": secret_name,
+ }
+ },
+ )
+ # decode base64 encoded string
+ wandb_api_key = base64.b64decode(
+ # pyrefly: ignore [bad-index]
+ response["result"]["secret"]["value"]
+ ).decode("utf-8")
+ print("[wandb] Successfully retrieved API key from keychain.")
+ os.environ["WANDB_API_KEY"] = wandb_api_key
+ os.environ["WANDB_BASE_URL"] = "https://meta.wandb.io/"
+ except Exception as keychain_exception:
+ print(
+ f"[wandb] Failed to retrieve API key from keychain. {keychain_exception=}"
+ )
+ raise RuntimeError(
+ "Failed to retrieve wandb API key. Cannot launch job"
+ ) from keychain_exception
+
+
+async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None):
+ """Main module for launching mast jobs for GRPO training.
+
+ Args:
+ cfg: Configuration dictionary
+ mode: "detached" (default) launches MAST job with client in MAST,
+ "remote" runs training directly (used when client runs in MAST)
+ extra_args: Additional CLI arguments to pass through to the client
+ """
+ if cfg.get(LAUNCHER_KEY, Launcher.MAST.value) != Launcher.MAST.value:
+ raise ValueError("Launcher must be MAST.")
+
+ # Job name should already be set from CLI args in __main__ section
+ # No need to modify it further here
+ if cfg.get(JOB_NAME_KEY, None) is None:
+ raise ValueError("Job name is required but not provided")
+
+ launcher_config = LauncherConfig(
+ launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.MAST.value)),
+ job_name=cfg.get(JOB_NAME_KEY, None),
+ services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
+ actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
+ )
+
+ if mode == "detached":
+ # In detached mode, just launch the MAST job with client role included
+ launcher = MastLauncher(
+ launcher_config,
+ detached=True,
+ extra_args=extra_args or [],
+ )
+ await launcher.launch_mast_job()
+ else:
+ # In remote mode, we're already running inside MAST, so set up wandb api key, mount directory,
+ # init provisioner and run training
+ setup_wandb_api_key()
+ mount_mnt_directory("/mnt/wsfuse")
+ await init_provisioner(ProvisionerConfig(launcher_config=launcher_config))
+ await grpo_main(cfg)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="detached",
+ choices=["detached", "remote"],
+ help="Run mode: 'detached' for launching MAST job with client in MAST, 'remote' for running training directly",
+ )
+ parser.add_argument(
+ "--job-name",
+ type=str,
+ default=None,
+ help="MAST job name (required - generated by launch.sh)",
+ )
+ args, remaining = parser.parse_known_args()
+
+ # Replace sys.argv with remaining args so @parse can work
+ sys.argv = [sys.argv[0]] + remaining
+
+ @parse
+ def _main(cfg):
+ # Override job name from CLI
+ if args.job_name:
+ cfg[JOB_NAME_KEY] = args.job_name
+ asyncio.run(main(cfg, mode=args.mode, extra_args=remaining))
+
+ _main() # @parse grabs the cfg from CLI
diff --git a/apps/mast/qwen3_1_7b_mast.yaml b/.meta/mast/qwen3_1_7b_mast.yaml
similarity index 68%
rename from apps/mast/qwen3_1_7b_mast.yaml
rename to .meta/mast/qwen3_1_7b_mast.yaml
index 58d879579..27f434def 100644
--- a/apps/mast/qwen3_1_7b_mast.yaml
+++ b/.meta/mast/qwen3_1_7b_mast.yaml
@@ -1,16 +1,15 @@
# Grouped Relative Policy Optimization (GRPO)
-# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+# >>> ./.meta/mast/launch.sh .meta/mast/qwen3_1_7b_mast.yaml
# Global configuration
group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
-model: "Qwen/Qwen3-1.7B"
+model: "/mnt/wsfuse/teamforge/hf/qwen3_1.7b"
off_by_n: 1 # Off by one by default
launcher: mast
-job_name: forge-qwen3-1_7b
-checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
# Main loop configuration
rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -18,15 +17,16 @@ rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to
# Observability configuration
metric_logging:
wandb:
+ entity: torchforge
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ logging_mode: global_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
- path: "openai/gsm8k"
+ path: /mnt/wsfuse/teamforge/hf/gsm8k
revision: "main"
data_split: "train"
streaming: true
@@ -34,15 +34,15 @@ dataset:
# Policy configuration
policy:
- engine_config:
- model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+ model: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
tensor_parallel_size: 1
pipeline_parallel_size: 1
- enforce_eager: false
+ enforce_eager: ${not:${compile}}
# TODO: Had to disable this becasue vLLm wouldn't like
# needs to revisited.
disable_custom_all_reduce: true
- sampling_config:
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -53,7 +53,8 @@ trainer:
model:
name: qwen3
flavor: 1.7B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5
+ hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
+ # hf_assets_path: hf://${model}
optimizer:
name: AdamW
lr: 1e-5
@@ -61,14 +62,14 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -79,8 +80,9 @@ trainer:
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5
+ initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
initial_load_in_hf: true
+ folder: ${checkpoint_folder}
last_save_in_hf: true
interval: 500
async_mode: "disabled"
@@ -95,7 +97,7 @@ trainer:
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
@@ -104,12 +106,14 @@ ref_model:
model:
name: qwen3
flavor: 1.7B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5
+ hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
+ # hf_assets_path: hf://${model}
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -119,20 +123,21 @@ ref_model:
expert_parallel_degree: 1
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5
+ initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
+ folder: ""
initial_load_in_hf: true
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
- num_replicas: 2
+ procs: ${policy.engine_args.tensor_parallel_size}
+ num_replicas: 1
with_gpus: true
mesh_name: policy
hosts: 1
ref_model:
procs: 1
- num_replicas: 2
+ num_replicas: 1
with_gpus: true
mesh_name: ref_model
hosts: 1
diff --git a/apps/mast/qwen3_32b_mast.yaml b/.meta/mast/qwen3_32b_mast.yaml
similarity index 64%
rename from apps/mast/qwen3_32b_mast.yaml
rename to .meta/mast/qwen3_32b_mast.yaml
index 0db8f4af3..9a41b9f9f 100644
--- a/apps/mast/qwen3_32b_mast.yaml
+++ b/.meta/mast/qwen3_32b_mast.yaml
@@ -1,16 +1,15 @@
# Grouped Relative Policy Optimization (GRPO)
-# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+# >>> ./.meta/mast/launch.sh .meta/mast/qwen3_32b_mast.yaml
# Global configuration
group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
-model: "Qwen/Qwen3-32B"
+model: "/mnt/wsfuse/teamforge/hf/qwen3_32b"
off_by_n: 1 # Off by one by default
launcher: mast
-job_name: forge-qwen3-32b
-checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
# Main loop configuration
rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -18,15 +17,16 @@ rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to
# Observability configuration
metric_logging:
wandb:
+ entity: torchforge
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ logging_mode: global_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
- path: "openai/gsm8k"
+ path: /mnt/wsfuse/teamforge/hf/gsm8k
revision: "main"
data_split: "train"
streaming: true
@@ -34,15 +34,15 @@ dataset:
# Policy configuration
policy:
- engine_config:
- model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+ model: /mnt/wsfuse/teamforge/hf/qwen3_32b
tensor_parallel_size: 2
pipeline_parallel_size: 1
- enforce_eager: false
+ enforce_eager: ${not:${compile}}
# TODO: Had to disable this becasue vLLm wouldn't like
# needs to revisited.
disable_custom_all_reduce: true
- sampling_config:
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -53,7 +53,7 @@ trainer:
model:
name: qwen3
flavor: 32B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
+ hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_32b
optimizer:
name: AdamW
lr: 1e-5
@@ -61,31 +61,32 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
- data_parallel_shard_degree: 4
- tensor_parallel_degree: 2
+ data_parallel_shard_degree: 8
+ tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
+ initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_32b
initial_load_in_hf: true
+ folder: ${checkpoint_folder}
last_save_in_hf: true
interval: 500
async_mode: "disabled"
activation_checkpoint:
- mode: selective
+ mode: full
selective_ac_option: op
comm:
# TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
@@ -95,7 +96,7 @@ trainer:
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
@@ -104,34 +105,39 @@ ref_model:
model:
name: qwen3
flavor: 32B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
+ hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_32b
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
- tensor_parallel_degree: 2
+ tensor_parallel_degree: 4
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
+ initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_32b
+ folder: ""
initial_load_in_hf: true
-
+ comm:
+ # TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
+ # from oilfs if the traienr is not in the same region as in oilfs
+ init_timeout_seconds: 1200
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 2
with_gpus: true
mesh_name: policy
hosts: 1
ref_model:
- procs: 1
+ procs: 4
num_replicas: 2
with_gpus: true
mesh_name: ref_model
@@ -141,7 +147,6 @@ services:
num_replicas: 1
with_gpus: false
mesh_name: reward_actor
-
actors:
dataset:
procs: 1
diff --git a/apps/mast/qwen3_4b_mast.yaml b/.meta/mast/qwen3_4b_mast.yaml
similarity index 69%
rename from apps/mast/qwen3_4b_mast.yaml
rename to .meta/mast/qwen3_4b_mast.yaml
index 92119055a..88e6dbfc9 100644
--- a/apps/mast/qwen3_4b_mast.yaml
+++ b/.meta/mast/qwen3_4b_mast.yaml
@@ -1,16 +1,15 @@
# Grouped Relative Policy Optimization (GRPO)
-# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+# >>> ./.meta/mast/launch.sh .meta/mast/qwen3_4b_mast.yaml
# Global configuration
group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-4B"
off_by_n: 1 # Off by one by default
launcher: mast
-job_name: forge-qwen3-4b
-checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
# Main loop configuration
rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
@@ -18,15 +17,16 @@ rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to
# Observability configuration
metric_logging:
wandb:
+ entity: torchforge
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ logging_mode: global_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
- path: "openai/gsm8k"
+ path: /mnt/wsfuse/teamforge/hf/gsm8k
revision: "main"
data_split: "train"
streaming: true
@@ -34,15 +34,15 @@ dataset:
# Policy configuration
policy:
- engine_config:
- model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+ model: /mnt/wsfuse/teamforge/hf/qwen3_4b
tensor_parallel_size: 2
pipeline_parallel_size: 1
- enforce_eager: false
+ enforce_eager: ${not:${compile}}
# TODO: Had to disable this becasue vLLm wouldn't like
# needs to revisited.
disable_custom_all_reduce: true
- sampling_config:
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -53,7 +53,8 @@ trainer:
model:
name: qwen3
flavor: 4B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed
+ hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_4b
+ # hf_assets_path: hf://${model}
optimizer:
name: AdamW
lr: 1e-5
@@ -61,14 +62,14 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 4
@@ -79,8 +80,9 @@ trainer:
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed
+ initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_4b
initial_load_in_hf: true
+ folder: ${checkpoint_folder}
last_save_in_hf: true
interval: 500
async_mode: "disabled"
@@ -95,7 +97,7 @@ trainer:
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
@@ -104,12 +106,14 @@ ref_model:
model:
name: qwen3
flavor: 4B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed
+ hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_4b
+ # hf_assets_path: hf://${model}
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -119,13 +123,14 @@ ref_model:
expert_parallel_degree: 1
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed
+ initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_4b
+ folder: ""
initial_load_in_hf: true
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 2
with_gpus: true
mesh_name: policy
@@ -144,7 +149,7 @@ services:
actors:
dataset:
- procs: 8
+ procs: 1
with_gpus: false
mesh_name: dataset
trainer:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a9bad2ca..eb2104e2b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,11 +39,17 @@ repos:
hooks:
- id: ufmt
additional_dependencies:
- - black == 22.12.0
- - usort == 1.0.5
+ - black == 24.4.2
+ - usort == 1.0.8.post1
- repo: https://github.com/jsh9/pydoclint
rev: 0.5.12
hooks:
- id: pydoclint
args: [--config=pyproject.toml]
+
+- repo: https://github.com/fastai/nbdev.git
+ rev: 2.4.5
+ hooks:
+ - id: nbdev_clean
+ args: [--clear_all]
diff --git a/README.md b/README.md
index 5142b91a0..265aea2d6 100644
--- a/README.md
+++ b/README.md
@@ -1,26 +1,28 @@
-#
Forge
-
+#
torchforge
#### A PyTorch-native agentic RL library that lets you focus on algorithms—not infra.
+[](https://github.com/meta-pytorch/forge/actions/workflows/gpu_test.yaml?query=branch%3Amain)
+[](https://meta-pytorch.org/torchforge/)
+[](https://discord.gg/YsTYBh6PD9)
## Overview
-The primary purpose of the Forge ecosystem is to delineate infra concerns from model concerns thereby making RL experimentation easier. Forge delivers this by providing clear RL abstractions and one scalable implementation of these abstractions. When you need fine-grained control over placement, fault handling/redirecting training loads during a run, or communication patterns, the primitives are there. When you don’t, you can focus purely on your RL algorithm.
+The primary purpose of the torchforge ecosystem is to separate infra concerns from model concerns thereby making RL experimentation easier. torchforge delivers this by providing clear RL abstractions and one scalable implementation of these abstractions. When you need fine-grained control over placement, fault handling/redirecting training loads during a run, or communication patterns, the primitives are there. When you don’t, you can focus purely on your RL algorithm.
Key features:
- Usability for rapid research (isolating the RL loop from infrastructure)
- Hackability for power users (all parts of the RL loop can be easily modified without interacting with infrastructure)
- Scalability (ability to shift between async and synchronous training and across thousands of GPUs)
-> ⚠️ **Early Development Warning** Forge is currently in an experimental
+> ⚠️ **Early Development Warning** torchforge is currently in an experimental
> stage. You should expect bugs, incomplete features, and APIs that may change
> in future versions. The project welcomes bugfixes, but to make sure things are
> well coordinated you should discuss any significant change before starting the
> work. It's recommended that you signal your intention to contribute in the
> issue tracker, either by filing a new issue or by claiming an existing one.
-## 📖 Documentation (Coming Soon)
+## 📖 Documentation
-View Forge's hosted documentation (coming soon)
+View torchforge's hosted documentation: https://meta-pytorch.org/torchforge.
## Tutorials
@@ -28,45 +30,34 @@ You can also find our notebook tutorials (coming soon)
## Installation
-### Basic
-
-Forge requires the latest PyTorch nightly with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan). For convenience,
-we have pre-packaged these dependencies as wheels in assets/wheels. (Note that the basic install script
-uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but could be easily extended to other Linux OS.)
+torchforge requires PyTorch 2.9.0 with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan).
-Forge requires the Github CLI (gh) to download a compatible vLLM package. See [here](https://github.com/cli/cli#installation) for gh install instructions before continuting. Please login to gh with your Github account before continuing with `gh auth login`. You may use either https or ssh as the protocol for authentication.
+Install torchforge with:
```bash
-conda create -n forge python=3.10
+conda create -n forge python=3.12
conda activate forge
./scripts/install.sh
```
-Optional: By default, the packages installation uses conda. If user wants to install system packages on the target machine instead of conda, they can pass the `--use-sudo` to the installation script: `./script/install.sh --use-sudo`.
+The install script installs system dependencies along with torchforge. Note that this install script uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but could be easily extended to other Linux OS.
-After install, you can run the following command and should see output confirming GRPO training is running (you need a minimum 3 GPU devices):
+Optional: By default, the packages installation uses conda. If you want to install system packages on the target machine instead of conda, you can pass the `--use-sudo` flag to the installation script: `./scripts/install.sh --use-sudo`.
-```
-python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
-```
+> **Note:** We are actively working on enabling pure `uv` installation. Currently, Conda is the recommended approach. `uv` support is not fully working at the moment but is being tracked in [issue #494](https://github.com/meta-pytorch/torchforge/issues/494).
-If you need to re-build the wheels for whatever reason, you can do so with:
-```bash
-./scripts/build_wheels.sh
-```
+After install, you can run the following command and should see output confirming GRPO training is running (you need a minimum 3 GPU devices):
-For your information, since the vLLM wheel is too large for GitHub, we uploaded it as a release in the `install.sh` script:
```
-$ gh release create v0.0.0 assets/wheels/vllm-*.whl --title "Forge Wheels v0.0.0"
+python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
```
## Quick Start
-To run SFT for Llama3 8B, run
+To run SFT on a Llama3 8B model, run
```bash
-uv run forge download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
-uv run forge run --nproc_per_node 2 apps/sft/main.py --config apps/sft/llama3_8b.yaml
+python -m apps.sft.main --config apps/sft/llama3_8b.yaml
```
### Citation
diff --git a/apps/grpo/README.md b/apps/grpo/README.md
new file mode 100644
index 000000000..750aeb095
--- /dev/null
+++ b/apps/grpo/README.md
@@ -0,0 +1,64 @@
+Last updated: 2025-12-05
+
+# GRPO on GSM8K
+
+Training GRPO (Grouped Relative Policy Optimization) on GSM8K grade school math word problems.
+
+## Dataset
+
+GSM8K consists of grade school math word problems that require multi-step arithmetic reasoning. Models generate solutions with chain-of-thought reasoning and provide final answers in `\boxed{}` format.
+
+**Example Input:**
+```
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+Cutting Knowledge Date: December 2023
+Today Date: 26 Jul 2024
+
+Put all your scratchpad work between and tags.
+ Your final answer should be between and tags otherwise it will not be scored.<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+Lucille is painting her room. Two of her walls are 3 meters wide and 2 meters tall. The third wall is 5 meters wide and 2 meters tall. The final wall is 4 meters wide and 2 meters tall. If each can of paint covers 2 square meters, how many cans of paint does Lucille need?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+```
+
+**Example Output:**
+```
+
+First, let's calculate the area of each wall:
+
+- Two walls of 3 meters x 2 meters = 3 x 2 = 6 square meters per wall. Since there are two of these, 6 * 2 = 12 square meters.
+- The third wall is 5 meters x 2 meters = 10 square meters.
+- The final wall is 4 meters x 2 meters = 8 square meters.
+Total wall area = 12 + 10 + 8 = 30 square meters.
+
+Since each can covers 2 square meters, we need to divide the total wall area by the area one can covers: 30 / 2 = 15.
+
+
+
+15
+```
+
+## Quick Start
+
+**Llama 3.1 8B** (recommended for learning, requires 5 GPUs as is, not optimized):
+```bash
+python -m apps.grpo.main --config apps/grpo/llama3_8b.yaml
+```
+
+**Qwen3 1.7B** (NOTE: Qwen3 is already saturated on GSM8K, so rewards will **not** increase. Requires 3 GPUs, not optimized):
+```bash
+python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+```
+
+## Expected Results
+
+For **Llama 3.1 8B**, training rewards should rise above 0.8 within the first few steps as the model learns the task.
+
+
+
+## Configurations
+
+- `llama3_8b.yaml` - Meta Llama 3.1 8B Instruct
+- `qwen3_1_7b.yaml` - Qwen3 1.7B
+- `qwen3_8b.yaml` - Qwen3 8B
+- `qwen3_32b.yaml` - Qwen3 32B
diff --git a/apps/grpo/data.py b/apps/grpo/data.py
new file mode 100644
index 000000000..6d0ebb94e
--- /dev/null
+++ b/apps/grpo/data.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+from datasets import load_dataset
+from forge.controller.actor import ForgeActor
+from forge.observability.metrics import record_metric, Reduce
+from monarch.actor import endpoint
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+@dataclass
+class DatasetActor(ForgeActor):
+ """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+ path: str = "openai/gsm8k"
+ revision: str = "main"
+ data_split: str = "train"
+ streaming: bool = True
+ model: str = ""
+ seed: int = 42
+
+ @endpoint
+ async def setup(self):
+ self._tokenizer = get_tokenizer(self.model)
+ self._epoch = 0
+
+ def gsm8k_transform(sample):
+ system_prompt = """
+ Put all your scratchpad work between and tags.
+ Your final answer should be between and tags otherwise it will not be scored.
+ """
+ request: str = sample["question"]
+ as_chat = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": request},
+ ]
+ formatted_request = self._tokenizer.apply_chat_template(
+ as_chat,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ target: str = sample["answer"]
+ formatted_target = target.split("#### ")[1]
+ return {"request": formatted_request, "target": formatted_target}
+
+ self._base_dataset = load_dataset(
+ self.path, self.revision, split=self.data_split, streaming=self.streaming
+ )
+ self._base_dataset = self._base_dataset.map(gsm8k_transform)
+ self._base_dataset = self._base_dataset.shuffle(seed=self.seed)
+ self._base_dataset.set_epoch(self._epoch)
+ self._iterator = iter(self._base_dataset)
+
+ @endpoint
+ async def sample(self) -> dict[str, str] | None:
+ try:
+ sample = next(self._iterator)
+
+ record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
+
+ return sample
+ except StopIteration:
+ # Restart iterator for next epoch with reshuffling
+ self._epoch += 1
+ print(
+ f"Dataset epoch {self._epoch - 1} completed. Starting epoch {self._epoch}"
+ )
+ self._base_dataset.set_epoch(self._epoch)
+ self._iterator = iter(self._base_dataset)
+ return next(self._iterator)
+
+ @endpoint
+ async def pad_token(self):
+ # Use pad_token_id if available, otherwise use eos_token_id
+ # Llama models don't have a pad token by default
+ if self._tokenizer.pad_token_id is not None:
+ return self._tokenizer.pad_token_id
+ else:
+ return self._tokenizer.eos_token_id
diff --git a/src/forge/data/rewards.py b/apps/grpo/grading.py
similarity index 77%
rename from src/forge/data/rewards.py
rename to apps/grpo/grading.py
index 29a86fc3a..26b6a8e29 100644
--- a/src/forge/data/rewards.py
+++ b/apps/grpo/grading.py
@@ -6,10 +6,8 @@
import re
-from forge.interfaces import Reward
-
-class MathReward(Reward):
+class MathReward:
"""Reward class for evaluating math correctness."""
def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
@@ -58,16 +56,29 @@ def _to_float(self, text: str) -> float | None:
return None
-class ThinkingReward(Reward):
- """Reward class for evaluating use of tags in reasoning."""
+class ThinkingReward:
+ """Reward class for evaluating use of thinking tags in reasoning.
+
+ Args:
+ partial_reward: Reward for partial tag usage (incomplete/malformed)
+ full_reward: Reward for well-formed thinking blocks with content
+ tag: Tag name to use (default "think", can use "思考" for Japanese, etc.)
+ """
- def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
+ def __init__(
+ self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think"
+ ):
self.partial_reward = partial_reward
self.full_reward = full_reward
+ self.tag = tag
+ # Build regex patterns for the specified tag
self._THINK_BLOCK_RE = re.compile(
- r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
+ rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>",
+ re.IGNORECASE | re.DOTALL,
+ )
+ self._THINK_TAG_ATTEMPT_RE = re.compile(
+ rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE
)
- self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)
def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
"""Compute thinking reward."""
diff --git a/apps/mast/qwen3_14b_mast.yaml b/apps/grpo/llama3_8b.yaml
similarity index 50%
rename from apps/mast/qwen3_14b_mast.yaml
rename to apps/grpo/llama3_8b.yaml
index 83d5b8103..6a887ebc3 100644
--- a/apps/mast/qwen3_14b_mast.yaml
+++ b/apps/grpo/llama3_8b.yaml
@@ -1,28 +1,23 @@
# Grouped Relative Policy Optimization (GRPO)
-# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+# >>> python -m apps.grpo.main --config apps/grpo/llama3_8b.yaml
# Global configuration
-group_size: 8
-batch_size: 16
-max_req_tokens: 512
-max_res_tokens: 512
-model: "Qwen/Qwen3-14B"
+group_size: 4
+local_batch_size: 4 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 2048
+model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
off_by_n: 1 # Off by one by default
-launcher: mast
-job_name: forge-qwen3-14b
-checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/
-
-# Main loop configuration
-rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
# Observability configuration
metric_logging:
wandb:
- project: "grpo-training"
- group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ project: grpo-training
+ group: grpo_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
@@ -34,15 +29,12 @@ dataset:
# Policy configuration
policy:
- engine_config:
- model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+ model: ${model}
tensor_parallel_size: 2
pipeline_parallel_size: 1
- enforce_eager: false
- # TODO: Had to disable this becasue vLLm wouldn't like
- # needs to revisited.
- disable_custom_all_reduce: true
- sampling_config:
+ enforce_eager: ${not:${compile}}
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -51,9 +43,9 @@ policy:
# Trainer configuration
trainer:
model:
- name: qwen3
- flavor: 14B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
+ name: llama3
+ flavor: 8B
+ hf_assets_path: hf://${model}
optimizer:
name: AdamW
lr: 1e-5
@@ -61,55 +53,55 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
- data_parallel_shard_degree: 4
- tensor_parallel_degree: 2
+ data_parallel_shard_degree: -1
+ tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
- initial_load_in_hf: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
last_save_in_hf: true
interval: 500
async_mode: "disabled"
activation_checkpoint:
mode: selective
selective_ac_option: op
- comm:
- # TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
- # from oilfs if the traienr is not in the same region as in oilfs
- init_timeout_seconds: 1200
- dcp_path: ${checkpoint_folder}
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
- dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
+ # This should match the dp_size of TorchTitan
+ # Here it's set explicitly to 2, because we've set
+ # 2 GPUs for the trainer and we're using full FSDP.
+ dp_size: 2
# Reference model configuration
ref_model:
model:
- name: qwen3
- flavor: 14B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
+ name: llama3
+ flavor: 8B
+ hf_assets_path: hf://${model}
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -118,24 +110,21 @@ ref_model:
context_parallel_degree: 1
expert_parallel_degree: 1
checkpoint:
- enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
+ initial_load_path: hf://${model}
initial_load_in_hf: true
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
- num_replicas: 2
+ procs: ${policy.engine_args.tensor_parallel_size}
+ num_replicas: 1
with_gpus: true
mesh_name: policy
- hosts: 1
ref_model:
procs: 1
- num_replicas: 2
+ num_replicas: 1
with_gpus: true
mesh_name: ref_model
- hosts: 1
reward_actor:
procs: 1
num_replicas: 1
@@ -148,10 +137,9 @@ actors:
with_gpus: false
mesh_name: dataset
trainer:
- procs: 8
+ procs: 2
with_gpus: true
mesh_name: trainer
- hosts: 1
replay_buffer:
procs: 1
with_gpus: false
diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 2439100d9..224c7b4d5 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -7,338 +7,110 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
import asyncio
-import time
import uuid
-from dataclasses import dataclass
-from typing import Any, Callable
import torch
-import torch.nn.functional as F
import torchstore as ts
-from datasets import load_dataset
-from forge.actors._torchstore_utils import (
- get_dcp_whole_state_dict_key,
- get_param_prefix,
-)
-from forge.actors.policy import Policy
+import yaml
+from apps.grpo.data import DatasetActor
+from apps.grpo.grading import MathReward, ThinkingReward
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
-from forge.actors.trainer import RLTrainer
-from forge.cli.config import parse
-from forge.controller.actor import ForgeActor
-from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
+from forge.actors.trainer import TitanTrainer
from forge.controller.provisioner import init_provisioner, shutdown
-from forge.data.rewards import MathReward, ThinkingReward
+from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
-
-from forge.types import (
- Launcher,
- LauncherConfig,
- ProcessConfig,
- ProvisionerConfig,
- ServiceConfig,
-)
+from forge.rl import collate, ComputeAdvantages, Episode, Policy, RewardActor
+from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.checkpoint import drop_weights
+from forge.util.config import parse
+from forge.util.logging import get_logger
from forge.util.ops import compute_logprobs
-from monarch.actor import endpoint
-from omegaconf import DictConfig
-from vllm.transformers_utils.tokenizer import get_tokenizer
-
-
-@dataclass
-class Episode:
- # TODO: add adtional layer for multi-turn
- episode_id: str
- request: str
- policy_version: int
- pad_id: int
- request_len: int
- response_len: int
- target: Any | None = None
- # processed data
- response: str | None = None
- request_tokens: list[int] | None = None
- response_tokens: list[int] | None = None
- ref_logprobs: torch.Tensor | None = None
- reward: float | None = None
- advantage: float | None = None
-
- @property
- def request_tensor(self):
- tensor = torch.tensor(self.request_tokens, dtype=torch.long)
- if tensor.shape[0] < self.request_len: # left pad
- diff = self.request_len - tensor.shape[0]
- tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
- return tensor
-
- @property
- def response_tensor(self):
- tensor = torch.tensor(self.response_tokens, dtype=torch.long)
- if tensor.shape[0] < self.response_len: # right pad
- diff = self.response_len - tensor.shape[0]
- tensor = F.pad(tensor, (0, diff), value=self.pad_id)
- return tensor
-
-
-@dataclass
-class Group:
- group_id: str
- episodes: list[Episode]
-
- @classmethod
- def new_group(
- cls,
- group_id: int,
- group_size: int,
- request: str,
- policy_version: int,
- pad_id: int,
- request_len: int,
- response_len: int,
- target: Any = None,
- ):
- episodes = []
- for _ in range(group_size):
- episodes.append(
- Episode(
- episode_id=str(uuid.uuid4()),
- request=request,
- policy_version=policy_version,
- pad_id=pad_id,
- request_len=request_len,
- response_len=response_len,
- target=target,
- )
- )
- return cls(str(group_id), episodes)
-
+from omegaconf import DictConfig, OmegaConf
-def collate(batches: list[list[Episode]]):
- inputs = []
- targets = []
- for batch in batches:
- request = [e.request_tensor for e in batch]
- request = torch.stack(request) # [b x s]
-
- response = [e.response_tensor for e in batch]
- response = torch.stack(response) # [b x s]
-
- ref_logprobs = [e.ref_logprobs for e in batch]
- ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]
-
- advantages = [e.advantage for e in batch]
- advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]
-
- pad_id = batch[0].pad_id
- mask = response != pad_id
-
- input = {"tokens": torch.cat([request, response], dim=1)}
- target = {
- "response": response,
- "ref_logprobs": ref_logprobs,
- "advantages": advantages,
- "padding_mask": mask,
- }
- inputs.append(input)
- targets.append(target)
- return inputs, targets
+logger = get_logger("INFO")
+# TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
+# Currently duplicated because of function signature differences:
+# - This function takes logits + response, computes logprobs internally
+# - SimpleGRPOLoss takes pre-computed logprobs
+# - TitanTrainer passes logits, so would need wrapper or signature change
+# Consider refactoring TitanTrainer's loss interface to standardize this.
def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
ref_logprobs: torch.Tensor,
advantages: torch.Tensor,
padding_mask: torch.Tensor,
- beta: float = 0.1,
+ beta: float = 1e-6,
) -> torch.Tensor:
- """
- Example GRPO Loss Function for RLTrainer
- """
logprobs: torch.Tensor = compute_logprobs(logits, response)
-
- # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
- per_token_loss = -(per_token_policy_loss - beta * kl)
- loss = (
- ((per_token_loss * padding_mask).sum(dim=1))
- / (padding_mask.sum(dim=1).clamp(min=1.0))
- ).mean()
- return loss
+ # Compute mean KL per valid token
+ mean_kl = (
+ ((kl * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0))
+ ).mean()
-@dataclass
-class RewardActor(ForgeActor):
- """Reward actor that uses a list of scoring functions."""
-
- reward_functions: list[Callable]
+ # Compute mean policy loss per valid token
+ mean_policy_loss = (
+ ((per_token_policy_loss * padding_mask).sum(dim=1))
+ / (padding_mask.sum(dim=1).clamp(min=1.0))
+ ).mean()
- @endpoint
- async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
- total_rewards = 0.0
- for reward_fn in self.reward_functions:
- reward = reward_fn(prompt, response, target)
- total_rewards += reward
+ # Compute loss using the means (mathematically equivalent)
+ loss = -(mean_policy_loss - beta * mean_kl)
- # Get a name for the reward function (works for classes, functions, lambdas)
- reward_fn_name = getattr(
- reward_fn, "__name__", reward_fn.__class__.__name__
- )
- # per function reward
- record_metric(
- f"reward/evaluate_response/sum_{reward_fn_name}_reward",
- reward,
- Reduce.SUM,
- )
- record_metric(
- f"reward/evaluate_response/avg_{reward_fn_name}_reward",
- reward,
- Reduce.MEAN,
- )
- record_metric(
- f"reward/evaluate_response/std_{reward_fn_name}_reward",
- reward,
- Reduce.STD,
- )
+ # Log metrics
+ # TODO: Better design - have loss function return all metrics as a dict,
+ # then record them in rl_trainer so all training metrics are in one namespace
+ # and we avoid doing .item here, which is not compile friendly
+ record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
+ record_metric(
+ "grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
+ )
+ record_metric(
+ "grpo_loss/policy_gradient_loss", mean_policy_loss.item(), Reduce.MEAN
+ )
+ record_metric("grpo_loss/total_loss", loss.item(), Reduce.MEAN)
+ record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
+ record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)
+ return loss
- # avg total reward
- record_metric(
- "reward/evaluate_response/avg_total_reward",
- reward,
- Reduce.MEAN,
- )
- # count fn calls
- record_metric(
- f"reward/evaluate_response/count_{reward_fn_name}_calls",
- 1,
- Reduce.SUM,
- )
+async def main(cfg: DictConfig):
+ """Main GRPO training loop with rollout and training processes."""
+ # Convert OmegaConf config to plain dict
+ run_config_for_logging = OmegaConf.to_container(cfg, resolve=True)
- avg_reward = total_rewards / len(self.reward_functions)
- return avg_reward
-
-
-@dataclass
-class ComputeAdvantages(ForgeActor):
- """Compute advantages for GRPO using reward signals."""
-
- @endpoint
- async def compute(self, group: Group) -> list[float]:
- # TODO: add batch processing
- rewards = torch.tensor([[e.reward for e in group.episodes]])
- mean = rewards.mean(1, keepdim=True)
- std = rewards.std(1, keepdim=True)
- advantages = (rewards - mean) / (std + 1e-4)
- return advantages.squeeze(0).tolist()
-
-
-@dataclass
-class DatasetActor(ForgeActor):
- """Actor wrapper for HuggingFace dataset to provide async interface."""
-
- path: str = "openai/gsm8k"
- revision: str = "main"
- data_split: str = "train"
- streaming: bool = True
- model: str = "Qwen/Qwen3-1.7B"
-
- @endpoint
- def setup(self):
- self._tokenizer = get_tokenizer(self.model)
-
- def gsm8k_transform(sample):
- system_prompt = """
- Put all your scratchpad work between and tags.
- Your final answer should be between and tags otherwise it will not be scored.
- """
- request: str = sample["question"]
- as_chat = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": request},
- ]
- formatted_request = self._tokenizer.apply_chat_template(
- as_chat,
- tokenize=False,
- add_generation_prompt=True,
- )
- target: str = sample["answer"]
- formatted_target = target.split("#### ")[1]
- return {"request": formatted_request, "target": formatted_target}
+ # Log config
+ logger.info("=" * 30 + " CONFIGURATION " + "=" * 30)
+ logger.info(
+ yaml.dump(run_config_for_logging, default_flow_style=False, sort_keys=False)
+ )
- ds = load_dataset(
- self.path, self.revision, split=self.data_split, streaming=self.streaming
+ # ---- Global setups ---- #
+ provisioner = None
+ if cfg.get("provisioner", None) is not None:
+ provisioner = await init_provisioner(
+ ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
- ds = ds.map(gsm8k_transform)
- ds = ds.shuffle()
- self._iterator = iter(ds)
+ else:
+ provisioner = await init_provisioner()
- @endpoint
- async def sample(self) -> dict[str, str] | None:
- try:
- sample = next(self._iterator)
-
- # Record dataset metrics
- record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
- record_metric(
- "dataset/sample/avg_sample_len",
- len(sample["request"]),
- Reduce.MEAN,
- )
-
- return sample
- except StopIteration:
- return None
-
- @endpoint
- async def pad_token(self):
- return self._tokenizer.pad_token_id
-
-
-async def drop_weights(version: int):
- print(f"Dropping weights @ version {version}")
- start_time = time.perf_counter()
- prefix = get_param_prefix(version)
- matching_keys = await ts.keys(prefix)
- # TODO: once we have something like `get_meta()` in torchstore, we can just
- # query the type of the object instead of relying on keys.
- dcp_key = get_dcp_whole_state_dict_key(version)
- if dcp_key in matching_keys:
- dcp_handle = await ts.get(dcp_key)
- dcp_handle.drop()
- for key in matching_keys:
- await ts.delete(key)
- elapsed = time.perf_counter() - start_time
- print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
-
-
-async def main(cfg: DictConfig):
- """Main GRPO training loop with rollout and training processes."""
- group_size = cfg.group_size
- max_req_tokens = cfg.max_req_tokens
- max_res_tokens = cfg.max_res_tokens
+ metric_logging_cfg = cfg.get("metric_logging", {})
- # init provisioner
- await init_provisioner(
- ProvisionerConfig(
- launcher_config=LauncherConfig(
- launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
- job_name=cfg.get(JOB_NAME_KEY, None),
- services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
- actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
- )
- )
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
+ await mlogger.init_backends.call_one(
+ backend_config=metric_logging_cfg, run_config=run_config_for_logging
)
- # initialize before spawning services
- metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
- mlogger = await get_or_create_metric_logger()
- await mlogger.init_backends.call_one(metric_logging_cfg)
-
# ---- Setup services ---- #
- await ts.initialize(strategy=ts.ControllerStorageVolumes())
+
(
dataloader,
policy,
@@ -350,7 +122,7 @@ async def main(cfg: DictConfig):
) = await asyncio.gather(
DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
- RLTrainer.options(**cfg.actors.trainer).as_actor(
+ TitanTrainer.options(**cfg.actors.trainer).as_actor(
**cfg.trainer, loss=simple_grpo_loss
),
ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
@@ -363,13 +135,34 @@ async def main(cfg: DictConfig):
),
)
+ group_size = cfg.group_size
+ max_req_tokens = cfg.max_req_tokens
+ max_res_tokens = cfg.max_res_tokens
+
+ # Set max_steps to the configured value, or -1 if not specified or Null
+ max_steps = cfg.trainer.training.steps or -1
+
print("All services initialized successfully!")
+ shutdown_event = asyncio.Event()
+ # Here we spawn a torchstore storage volume per trainer process.
+ # We initialize after service initialization because torchstore currently
+ # requires access to the underlying proc meshes in the local rank strategy.
+ # We should be able to hide this in the future.
+ # TODO: support multiple host meshes
+ trainer_num_procs = cfg.actors.trainer["procs"]
+ trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
+ trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
+ await ts.initialize(
+ mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
+ strategy=ts.LocalRankStrategy(),
+ )
+ print("Torchstore successfully initialized with local rank strategy")
# ---- Core RL loops ---- #
async def continuous_rollouts():
rollout_count = 0
pad_id = await dataloader.pad_token.call_one()
- while True:
+ while not shutdown_event.is_set():
t = Tracer("main_perf/continuous_rollouts")
t.start()
sample = await dataloader.sample.call_one()
@@ -377,47 +170,89 @@ async def continuous_rollouts():
print("Dataloader is empty, exiting continuous rollout")
return
- t.step("data_loading")
-
prompt, target = sample["request"], sample["target"]
- responses = await policy.generate.route(prompt)
- # TODO: this shall be part of the responses metadata instead of a separate call
- version = await policy.get_version.route()
-
+ responses: list[Completion] = await policy.generate.route(prompt)
t.step("policy_generation")
- assert (
- len(responses) > 0
- ), "Sanity check: Responses should NEVER return empty"
- assert (
- version := responses[0].generator_version
- ) is not None, "Response must indicate a version"
- group = Group.new_group(
- group_id=rollout_count,
- group_size=group_size,
- request=prompt,
- policy_version=version,
- pad_id=pad_id,
- request_len=max_req_tokens,
- response_len=max_res_tokens,
- target=target,
- )
-
+ # Construct episodes and calculate rewards
+ episodes = []
input_ids = torch.ones(
(group_size, max_req_tokens + max_res_tokens),
dtype=torch.long,
- device="cuda",
)
- # Populate episode info and calculate rewards
- for i, (episode, response) in enumerate(zip(group.episodes, responses)):
- episode.request_tokens = response.prompt_ids
- episode.response_tokens = response.token_ids
- episode.response = response.text
+ for i, response in enumerate(responses):
+ episode = Episode(
+ episode_id=str(uuid.uuid4()),
+ pad_id=pad_id,
+ request_len=max_req_tokens,
+ response_len=max_res_tokens,
+ target=target,
+ request=prompt,
+ response=response.text,
+ completion=response,
+ )
+ (
+ episode.reward_breakdown,
+ episode.reward,
+ ) = await reward_actor.evaluate_response.route(
+ prompt=prompt, response=response.text, target=target
+ )
+ episodes.append(episode)
+
+ # Build input_ids for reference logprobs
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
- episode.reward = await reward_actor.evaluate_response.route(
- prompt=prompt, response=response.text, target=target
+
+ # Track token-based metrics
+ prompt_tokens = episode.completion.prompt_ids.shape[0]
+ response_tokens = episode.completion.token_ids.shape[0]
+
+ record_metric("episode/avg_prompt_tokens", prompt_tokens, Reduce.MEAN)
+ record_metric("episode/max_prompt_tokens", prompt_tokens, Reduce.MAX)
+ record_metric("episode/min_prompt_tokens", prompt_tokens, Reduce.MIN)
+ record_metric(
+ "episode/avg_response_tokens", response_tokens, Reduce.MEAN
+ )
+ record_metric(
+ "episode/max_response_tokens", response_tokens, Reduce.MAX
)
+ record_metric(
+ "episode/min_response_tokens", response_tokens, Reduce.MIN
+ )
+
+ # drop episodes if
+ # 1> reward std-dev is very small (including all 0s and all 1s)
+ # 2> any response was truncated (didn't end with EOS)
+ # TODO: change it to filter only truncated episodes instead of dropping entire group
+ rewards = [e.reward for e in episodes]
+ rewards_std = torch.std(torch.tensor(rewards))
+ is_low_variance = rewards_std < 1e-3
+ num_truncated = sum(
+ 1 for e in episodes if e.completion.stop_reason == "length"
+ )
+ is_truncated = num_truncated > 0
+ drop = is_low_variance or is_truncated
+
+ n = len(episodes)
+ record_metric(
+ "main/continuous_rollouts/episodes_dropped/low_variance",
+ n if is_low_variance else 0,
+ Reduce.SUM,
+ )
+ record_metric(
+ "main/continuous_rollouts/episodes_dropped/truncated",
+ num_truncated,
+ Reduce.SUM,
+ )
+ record_metric(
+ "main/continuous_rollouts/episodes_dropped/total",
+ n if drop else 0,
+ Reduce.SUM,
+ )
+
+ if drop:
+ del input_ids, episodes
+ continue
t.step("reward_evaluation")
@@ -426,18 +261,23 @@ async def continuous_rollouts():
)
t.step("reference_model_calculate_logprobs")
- for i, episode in enumerate(group.episodes):
+ for i, episode in enumerate(episodes):
episode.ref_logprobs = ref_logprobs[i]
del ref_logprobs, input_ids
- t.step("compute_logprobs")
- # Calculate advantages and add to replay buffer
- advantages = await compute_advantages.compute.call_one(group)
- for episode, advantage in zip(group.episodes, advantages):
+ advantages = await compute_advantages.compute.call_one(episodes)
+ for episode, advantage in zip(episodes, advantages):
episode.advantage = advantage
await replay_buffer.add.call_one(episode)
- # Log metrics
+ sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
+ sample["score"] = sample["reward"]
+ record_metric(
+ "main_samples/continuous_rollouts/sample_table",
+ sample,
+ Reduce.SAMPLE,
+ )
+
rollout_count += 1
record_metric(
"main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
@@ -448,7 +288,7 @@ async def continuous_training():
training_step = 0
restart_tracer = True # Flag to control when to restart tracer
- while True:
+ while max_steps == -1 or training_step < max_steps:
# Restart tracer when needed (initial start or after completing a training step)
# Otherwise, we cannot measure time waiting for buffer
if restart_tracer:
@@ -485,6 +325,10 @@ async def continuous_training():
# Flush metrics every training step to WandB
await mlogger.flush.call_one(training_step)
+ print(
+ f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
+ )
+
num_rollout_threads = cfg.get("rollout_threads", 1)
num_training_threads = cfg.get("training_threads", 1)
print(
@@ -496,31 +340,27 @@ async def continuous_training():
training_task = asyncio.create_task(continuous_training())
try:
- await asyncio.gather(*rollout_tasks, training_task)
+ await training_task
except KeyboardInterrupt:
print("Training interrupted by user")
- for rollout_task in rollout_tasks:
- rollout_task.cancel()
- training_task.cancel()
finally:
- print("Shutting down...")
-
- # give mlogger time to shutdown backends, otherwise they can stay running.
- # TODO (felipemello) find more elegant solution
- await mlogger.shutdown.call_one()
- await asyncio.sleep(2)
-
- await asyncio.gather(
- DatasetActor.shutdown(dataloader),
- policy.shutdown(),
- RLTrainer.shutdown(trainer),
- ReplayBuffer.shutdown(replay_buffer),
- ComputeAdvantages.shutdown(compute_advantages),
- ref_model.shutdown(),
- reward_actor.shutdown(),
- )
- # TODO - add a global shutdown that implicitly shuts down all services
- # and remote allocations
+ print("Shutting down... (this may take a few seconds)")
+ shutdown_event.set()
+
+ try:
+ # Give rollouts up to 5s to finish naturally
+ await asyncio.wait_for(
+ asyncio.gather(*rollout_tasks, return_exceptions=True),
+ timeout=5,
+ )
+ except asyncio.TimeoutError:
+ print("Timeout waiting for rollouts; forcing cancellation...")
+ for t in rollout_tasks:
+ t.cancel()
+ await asyncio.gather(*rollout_tasks, return_exceptions=True)
+
+ training_task.cancel()
+
await shutdown()
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 53eec5cfb..b0c08a28c 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -3,11 +3,12 @@
# Global configuration
group_size: 8
-batch_size: 16
-max_req_tokens: 512
-max_res_tokens: 512
+local_batch_size: 16 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 2048
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
# Main loop configuration
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
@@ -16,11 +17,11 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas
# Observability configuration
metric_logging:
wandb:
- project: "grpo-training"
- group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ project: grpo-training
+ group: grpo_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
@@ -32,12 +33,12 @@ dataset:
# Policy configuration
policy:
- engine_config:
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
model: ${model}
tensor_parallel_size: 1
pipeline_parallel_size: 1
- enforce_eager: false
- sampling_config:
+ enforce_eager: ${not:${compile}}
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -56,14 +57,14 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -74,8 +75,9 @@ trainer:
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: hf://${model}
- initial_load_in_hf: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
last_save_in_hf: true
interval: 500
async_mode: "disabled"
@@ -85,7 +87,7 @@ trainer:
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
@@ -96,10 +98,11 @@ ref_model:
flavor: 1.7B
hf_assets_path: hf://${model}
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -115,28 +118,35 @@ ref_model:
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
+ mesh_name: policy
with_gpus: true
ref_model:
procs: 1
num_replicas: 1
+ mesh_name: ref_model
with_gpus: true
reward_actor:
procs: 1
num_replicas: 1
+ mesh_name: reward_actor
with_gpus: false
actors:
dataset:
procs: 1
with_gpus: false
+ mesh_name: dataset
trainer:
procs: 1
with_gpus: true
+ mesh_name: trainer
replay_buffer:
procs: 1
with_gpus: false
+ mesh_name: replay_buffer
compute_advantages:
procs: 1
with_gpus: false
+ mesh_name: compute_advantages
diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index c46ee0620..9bec6a541 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -2,21 +2,22 @@
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml
# Global configuration
-group_size: 8
-batch_size: 16
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 16
+local_batch_size: 4 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 2048
model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
# Observability configuration
metric_logging:
wandb:
- project: "grpo-training"
- group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ project: grpo-training
+ group: grpo_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
@@ -28,13 +29,12 @@ dataset:
# Policy configuration
policy:
- use_vllm_builtin_load: true
- engine_config:
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
model: ${model}
tensor_parallel_size: 2
pipeline_parallel_size: 1
- enforce_eager: false
- sampling_config:
+ enforce_eager: ${not:${compile}}
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -42,8 +42,6 @@ policy:
# Trainer configuration
trainer:
- use_dcp: true
- use_vllm_builtin_load: true
model:
name: qwen3
flavor: 8B
@@ -55,14 +53,14 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
@@ -73,8 +71,9 @@ trainer:
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: hf://${model}
- initial_load_in_hf: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
last_save_in_hf: true
interval: 500
async_mode: "disabled"
@@ -84,7 +83,7 @@ trainer:
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
# This should match the dp_size of TorchTitan
# Here it's set explicitly to 2, because we've set
@@ -98,10 +97,11 @@ ref_model:
flavor: 8B
hf_assets_path: hf://${model}
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -116,28 +116,35 @@ ref_model:
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
with_gpus: true
+ mesh_name: policy
ref_model:
procs: 1
num_replicas: 1
with_gpus: true
+ mesh_name: ref_model
reward_actor:
procs: 1
num_replicas: 1
with_gpus: false
+ mesh_name: reward_actor
actors:
dataset:
procs: 1
with_gpus: false
+ mesh_name: dataset
trainer:
procs: 2
with_gpus: true
+ mesh_name: trainer
replay_buffer:
procs: 1
with_gpus: false
+ mesh_name: replay_buffer
compute_advantages:
procs: 1
with_gpus: false
+ mesh_name: compute_advantages
diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/slurm/qwen3_30b_a3b.yaml
similarity index 57%
rename from apps/grpo/qwen3_32b.yaml
rename to apps/grpo/slurm/qwen3_30b_a3b.yaml
index 3d1b80852..d4f35ba72 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/slurm/qwen3_30b_a3b.yaml
@@ -1,26 +1,34 @@
# Grouped Relative Policy Optimization (GRPO)
-# >>> python -m apps.grpo.main --config apps/grpo/qwen32b.yaml
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability
+# ./apps/grpo/slurm/submit.sh qwen3_30b_a3b
# Global configuration
-group_size: 2
-batch_size: 8
-max_req_tokens: 512
-max_res_tokens: 512
-model: "Qwen/Qwen3-32B"
+group_size: 4
+local_batch_size: 1 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 1024
+model: "Qwen/Qwen3-30B-A3B"
off_by_n: 1 # Off by one by default
+provisioner:
+ launcher: slurm
+ memMB: 2047962
+ cpu: 192
+ account: agentic-models
+ qos: h200_capabilities_shared
+
# Main loop configuration
-rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+rollout_threads: 32 # make this 4x the number of policy replicas seems to work well
# Observability configuration
metric_logging:
wandb:
- project: "grpo-training"
- group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ entity: agentic-models
+ project: grpo-training
+ group: grpo_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
@@ -32,12 +40,12 @@ dataset:
# Policy configuration
policy:
- engine_config:
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
model: ${model}
tensor_parallel_size: 4
pipeline_parallel_size: 1
enforce_eager: false
- sampling_config:
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -47,7 +55,7 @@ policy:
trainer:
model:
name: qwen3
- flavor: 32B
+ flavor: 30B-A3B
hf_assets_path: hf://${model}
optimizer:
name: AdamW
@@ -56,8 +64,8 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
@@ -71,11 +79,13 @@ trainer:
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
+ expert_tensor_parallel_degree: 1
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: hf://${model}
- initial_load_in_hf: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
last_save_in_hf: true
interval: 500
async_mode: "disabled"
@@ -84,26 +94,27 @@ trainer:
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
# dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
- dp_size: 8
+ dp_size: 4
# Reference model configuration
ref_model:
model:
name: qwen3
- flavor: 32B
+ flavor: 30B-A3B
hf_assets_path: hf://${model}
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
enable: false
parallelism:
data_parallel_replicate_degree: 1
- data_parallel_shard_degree: 1
- tensor_parallel_degree: 4
+ data_parallel_shard_degree: -1
+ tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
@@ -115,30 +126,37 @@ ref_model:
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
hosts: 1
with_gpus: true
+ mesh_name: policy
ref_model:
- procs: ${ref_model.parallelism.tensor_parallel_degree}
+ procs: 4
num_replicas: 1
with_gpus: true
+ mesh_name: ref_model
reward_actor:
procs: 1
num_replicas: 1
with_gpus: false
+ mesh_name: reward_actor
actors:
dataset:
procs: 1
with_gpus: false
+ mesh_name: dataset
trainer:
- procs: 8
+ procs: 4
hosts: 1
with_gpus: true
+ mesh_name: trainer
replay_buffer:
procs: 1
with_gpus: false
+ mesh_name: replay_buffer
compute_advantages:
procs: 1
with_gpus: false
+ mesh_name: compute_advantages
diff --git a/apps/grpo/slurm/qwen3_32b.yaml b/apps/grpo/slurm/qwen3_32b.yaml
new file mode 100644
index 000000000..ca4399558
--- /dev/null
+++ b/apps/grpo/slurm/qwen3_32b.yaml
@@ -0,0 +1,162 @@
+# Grouped Relative Policy Optimization (GRPO)
+# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability
+# ./apps/grpo/slurm/submit.sh qwen3_32b
+
+# Global configuration
+group_size: 16
+local_batch_size: 2 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 1024
+model: "Qwen/Qwen3-32B"
+off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
+
+provisioner:
+ launcher: slurm
+ memMB: 2047962
+ cpu: 192
+ account: agentic-models
+ qos: h200_capabilities_shared
+
+# Main loop configuration
+rollout_threads: 32 # make this 4x the number of policy replicas seems to work well
+
+# Observability configuration
+metric_logging:
+ wandb:
+ entity: agentic-models
+ project: grpo-training
+ group: grpo_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+ console:
+ logging_mode: global_reduce
+
+# Dataset configuration
+dataset:
+ path: "openai/gsm8k"
+ revision: "main"
+ data_split: "train"
+ streaming: true
+ model: ${model}
+
+# Policy configuration
+policy:
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+ model: ${model}
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ enforce_eager: ${not:${compile}}
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
+ n: ${group_size}
+ max_tokens: ${max_res_tokens}
+ temperature: 1.0
+ top_p: 1.0
+
+# Trainer configuration
+trainer:
+ model:
+ name: qwen3
+ flavor: 32B
+ hf_assets_path: hf://${model}
+ optimizer:
+ name: AdamW
+ lr: 1e-5
+ eps: 1e-8
+ lr_scheduler:
+ warmup_steps: 1
+ training:
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
+ max_norm: 1.0
+ steps: 1000000
+ dtype: bfloat16
+ gc_freq: 1
+ compile:
+ enable: ${compile}
+ parallelism:
+ data_parallel_replicate_degree: 1
+ data_parallel_shard_degree: 1
+ tensor_parallel_degree: 8
+ pipeline_parallel_degree: 1
+ context_parallel_degree: 1
+ expert_parallel_degree: 1
+ disable_loss_parallel: true
+ checkpoint:
+ enable: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
+ last_save_in_hf: true
+ interval: 500
+ async_mode: "disabled"
+ activation_checkpoint:
+ mode: full
+
+# Replay buffer configuration
+replay_buffer:
+ batch_size: ${local_batch_size}
+ max_policy_age: ${off_by_n}
+ # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
+ dp_size: 1
+
+# Reference model configuration
+ref_model:
+ model:
+ name: qwen3
+ flavor: 32B
+ hf_assets_path: hf://${model}
+ training:
+ seq_len: ${trainer.training.seq_len}
+ dtype: bfloat16
+ gc_freq: 1
+ compile:
+ enable: ${compile}
+ parallelism:
+ data_parallel_replicate_degree: 1
+ data_parallel_shard_degree: 1
+ tensor_parallel_degree: 4
+ pipeline_parallel_degree: 1
+ context_parallel_degree: 1
+ expert_parallel_degree: 1
+ checkpoint:
+ enable: true
+ initial_load_path: hf://${model}
+ initial_load_in_hf: true
+
+# All resource allocations
+services:
+ policy:
+ procs: ${policy.engine_args.tensor_parallel_size}
+ num_replicas: 4
+ hosts: 1
+ with_gpus: true
+ mesh_name: policy
+ ref_model:
+ procs: ${ref_model.parallelism.tensor_parallel_degree}
+ num_replicas: 1
+ with_gpus: true
+ mesh_name: ref_model
+ reward_actor:
+ procs: 1
+ num_replicas: 1
+ with_gpus: false
+ mesh_name: reward_actor
+
+actors:
+ dataset:
+ procs: 1
+ with_gpus: false
+ mesh_name: dataset
+ trainer:
+ procs: 8
+ hosts: 1
+ with_gpus: true
+ mesh_name: trainer
+ replay_buffer:
+ procs: 1
+ with_gpus: false
+ mesh_name: replay_buffer
+ compute_advantages:
+ procs: 1
+ with_gpus: false
+ mesh_name: compute_advantages
diff --git a/apps/mast/qwen3_8b_mast.yaml b/apps/grpo/slurm/qwen3_8b.yaml
similarity index 54%
rename from apps/mast/qwen3_8b_mast.yaml
rename to apps/grpo/slurm/qwen3_8b.yaml
index 7f2f99694..0922f2078 100644
--- a/apps/mast/qwen3_8b_mast.yaml
+++ b/apps/grpo/slurm/qwen3_8b.yaml
@@ -1,28 +1,32 @@
# Grouped Relative Policy Optimization (GRPO)
-# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+# ./apps/grpo/slurm/submit.sh qwen3_8b
# Global configuration
-group_size: 8
-batch_size: 16
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 16
+local_batch_size: 4 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 2048
model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default
-launcher: mast
-job_name: forge-qwen3-8b
-checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/
+compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
-# Main loop configuration
-rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
+
+provisioner:
+ launcher: slurm
+ memMB: 2047962
+ cpu: 192
+ account: agentic-models
+ qos: h200_capabilities_shared
# Observability configuration
metric_logging:
wandb:
- project: "grpo-training"
- group: "grpo_exp_${oc.env:USER}"
- reduce_across_ranks: True
+ entity: agentic-models
+ project: grpo-training
+ group: grpo_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
console:
- reduce_across_ranks: True
+ logging_mode: global_reduce
# Dataset configuration
dataset:
@@ -34,15 +38,12 @@ dataset:
# Policy configuration
policy:
- engine_config:
- model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model
+ engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+ model: ${model}
tensor_parallel_size: 2
pipeline_parallel_size: 1
- enforce_eager: false
- # TODO: Had to disable this becasue vLLm wouldn't like
- # needs to revisited.
- disable_custom_all_reduce: true
- sampling_config:
+ enforce_eager: ${not:${compile}}
+ sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -53,7 +54,7 @@ trainer:
model:
name: qwen3
flavor: 8B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model
+ hf_assets_path: hf://${model}
optimizer:
name: AdamW
lr: 1e-5
@@ -61,55 +62,55 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
- local_batch_size: ${batch_size}
- seq_len: 2048
+ local_batch_size: ${local_batch_size}
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
- data_parallel_shard_degree: 4
- tensor_parallel_degree: 2
+ data_parallel_shard_degree: -1
+ tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: true
checkpoint:
enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model
- initial_load_in_hf: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
last_save_in_hf: true
interval: 500
async_mode: "disabled"
activation_checkpoint:
mode: selective
selective_ac_option: op
- comm:
- # TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
- # from oilfs if the traienr is not in the same region as in oilfs
- init_timeout_seconds: 1200
- dcp_path: ${checkpoint_folder}
# Replay buffer configuration
replay_buffer:
- batch_size: ${batch_size}
+ batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
- dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
+ # This should match the dp_size of TorchTitan
+ # Here it's set explicitly to 2, because we've set
+ # 2 GPUs for the trainer and we're using full FSDP.
+ dp_size: 2
# Reference model configuration
ref_model:
model:
name: qwen3
flavor: 8B
- hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model
+ hf_assets_path: hf://${model}
training:
+ seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -118,24 +119,22 @@ ref_model:
context_parallel_degree: 1
expert_parallel_degree: 1
checkpoint:
- enable: true
- initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model
+ initial_load_path: hf://${model}
initial_load_in_hf: true
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
- num_replicas: 2
+ procs: ${policy.engine_args.tensor_parallel_size}
+ num_replicas: 1
+ hosts: 1
with_gpus: true
mesh_name: policy
- hosts: 1
ref_model:
procs: 1
- num_replicas: 2
+ num_replicas: 1
with_gpus: true
mesh_name: ref_model
- hosts: 1
reward_actor:
procs: 1
num_replicas: 1
@@ -148,10 +147,9 @@ actors:
with_gpus: false
mesh_name: dataset
trainer:
- procs: 8
+ procs: 2
with_gpus: true
mesh_name: trainer
- hosts: 1
replay_buffer:
procs: 1
with_gpus: false
diff --git a/apps/grpo/slurm/submit.sh b/apps/grpo/slurm/submit.sh
new file mode 100755
index 000000000..dd26eaa70
--- /dev/null
+++ b/apps/grpo/slurm/submit.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+CONFIG_NAME="${1}"
+
+sbatch --job-name="${CONFIG_NAME}" \
+ --export=ALL,CONFIG_NAME="${CONFIG_NAME}" \
+ apps/grpo/slurm/submit_grpo.sh
+
+
+# Usage:
+# ./apps/grpo/slurm/submit.sh qwen3_8b
+# ./apps/grpo/slurm/submit.sh qwen3_32b
+# ./apps/grpo/slurm/submit.sh qwen3_30b_a3b
diff --git a/apps/grpo/slurm/submit_grpo.sh b/apps/grpo/slurm/submit_grpo.sh
new file mode 100755
index 000000000..24e8447ad
--- /dev/null
+++ b/apps/grpo/slurm/submit_grpo.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#SBATCH --qos=h200_capabilities_shared
+#SBATCH --account=agentic-models
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-node=8
+#SBATCH --cpus-per-task=128
+#SBATCH --mem=500G
+#SBATCH --time=72:00:00
+
+echo "Starting GRPO training job"
+
+eval "$(conda shell.bash hook)"
+
+conda activate forge
+
+export TORCH_COMPILE_DISABLE=1
+unset SLURM_MEM_PER_CPU SLURM_MEM_PER_GPU SLURM_MEM_PER_NODE
+export TORCHSTORE_RDMA_ENABLED=0
+
+cd /storage/home/daniellepintz/torchforge
+
+srun python -m apps.grpo.main --config apps/grpo/slurm/${CONFIG_NAME}.yaml
diff --git a/apps/grpo/wandb_llama8b.png b/apps/grpo/wandb_llama8b.png
new file mode 100644
index 000000000..f0fc8eb52
Binary files /dev/null and b/apps/grpo/wandb_llama8b.png differ
diff --git a/apps/mast/README.md b/apps/mast/README.md
deleted file mode 100644
index 60a9b4146..000000000
--- a/apps/mast/README.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# Forge MAST Environment Setup
-
-A simple setup script to automatically configure your environment for running Forge with MAST jobs.
-
-## Quick Start
-
-⚠️ Important Note: the setup script will clone the forge repository under "/data/users/$USER".
-
-### 1. Run the Setup Script
-
-The `env_setup.sh` script will automatically:
-- ✅ Activate the required conda environment (`forge-8448524`)
-- ✅ Clone/update the Forge repository
-- ✅ Install Forge package dependencies
-- ✅ Mount the required oilfs workspace to `/mnt/wsfuse`
-- ✅ Configure your environment for MAST job submission
-
-```bash
-# Make the script executable
-chmod +x env_setup.sh
-
-# Run the setup
-./apps/mast/env_setup.sh
-
-```
-
-### 2. Submit MAST job
-
-```
-pip install --force-reinstall --no-deps . && python -m apps.mast.main --config apps/mast/qwen3_1_7b_mast.yaml
-```
-
-⚠️ Important Note: `pip install --force-reinstall --no-deps .` is required every time you make a change to the local codebase. This ensures your latest changes are installed before job submission.
diff --git a/apps/mast/main.py b/apps/mast/main.py
deleted file mode 100644
index cd5de0be9..000000000
--- a/apps/mast/main.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import asyncio
-import getpass
-import uuid
-
-from apps.grpo.main import main as grpo_main
-from forge.cli.config import parse
-from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
-from forge.controller.provisioner import init_provisioner
-
-from forge.types import (
- Launcher,
- LauncherConfig,
- ProcessConfig,
- ProvisionerConfig,
- ServiceConfig,
-)
-from omegaconf import DictConfig
-
-DEFAULT_CHECKPOINT_FOLDER_KEY = "checkpoint_folder"
-DEFAULT_CHECKPOINT_FOLDER = "/mnt/wsfuse/teamforge/forge_runs/"
-
-
-async def main(cfg: DictConfig):
- """Main module for launching mast jobs for GRPO training."""
- if cfg.get(LAUNCHER_KEY, Launcher.MAST.value) != Launcher.MAST.value:
- raise ValueError("Launcher must be MAST.")
-
- if cfg.get(JOB_NAME_KEY, None) is not None:
- # prepend user name to the job to avoid name collision
- cfg[JOB_NAME_KEY] = f"{getpass.getuser()}-{cfg[JOB_NAME_KEY]}"
- print(f"Overriding mast job name to {cfg[JOB_NAME_KEY]}")
-
- if cfg.get(DEFAULT_CHECKPOINT_FOLDER_KEY, DEFAULT_CHECKPOINT_FOLDER) is not None:
- # append job_name and guid to CP folder path to avoid path collision
- if cfg[DEFAULT_CHECKPOINT_FOLDER_KEY] == DEFAULT_CHECKPOINT_FOLDER:
- cfg[
- DEFAULT_CHECKPOINT_FOLDER_KEY
- ] = f"{cfg[DEFAULT_CHECKPOINT_FOLDER_KEY]}{cfg[JOB_NAME_KEY]}-{uuid.uuid4().hex[:6]}"
- print(f"Overriding checkpoint folder to {cfg[DEFAULT_CHECKPOINT_FOLDER_KEY]}")
-
- # init mast provisioner
- await init_provisioner(
- ProvisionerConfig(
- launcher_config=LauncherConfig(
- launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.MAST.value)),
- job_name=cfg.get(JOB_NAME_KEY, None),
- services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
- actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
- )
- )
- )
- await grpo_main(cfg)
-
-
-if __name__ == "__main__":
-
- @parse
- def _main(cfg):
- asyncio.run(main(cfg))
-
- _main() # @parse grabs the cfg from CLI
diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml
index 3d67cdb23..d9a5a9783 100644
--- a/apps/sft/llama3_8b.yaml
+++ b/apps/sft/llama3_8b.yaml
@@ -1,21 +1,22 @@
+# >>> python -m apps.sft.main --config apps/sft/llama3_8b.yaml
-# profiling:
-# enable_profiling: false
-
-metrics:
- logger: tensorboard
- freq:
- loss: 10
+# Config for supervised full finetuning using a Llama3.1 8B Instruct model
# TODO: required by torchtitan
# https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
comm:
trace_buf_size: 0
+model_name: "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
model:
name: llama3
flavor: 8B
- hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct
+ hf_assets_path: hf://${model_name}
+
+processes:
+ procs: 8
+ with_gpus: true
optimizer:
name: AdamW
@@ -26,24 +27,21 @@ lr_scheduler:
warmup_steps: 200
training:
- local_batch_size: 1
+ local_batch_size: 8
seq_len: 2048
max_norm: 1.0
steps: 1000
compile: false
+ datasets:
+ - path: "yahma/alpaca-cleaned"
+ split: "train[:95%]"
-validation:
- local_batch_size: 1
- freq: -1 # Change to a positive number to enable validation
- steps: 200 # Max steps to run validation. Validation disabled if negative.
-
-dataset:
- path: yahma/alpaca-cleaned
- split: train[:95%]
-
-dataset_val:
- path: yahma/alpaca-cleaned
- split: train[95%:]
+eval:
+ eval_every_n_steps: 50 # null = disabled
+ max_eval_steps: null # null = run until epoch completes
+ datasets:
+ - path: "yahma/alpaca-cleaned"
+ split: "train[95%:]"
parallelism:
data_parallel_replicate_degree: 1
@@ -56,9 +54,9 @@ parallelism:
checkpoint:
enable: true
- folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
- initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
- initial_load_in_hf: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
last_save_in_hf: true
interval: 500
async_mode: "disabled"
@@ -66,3 +64,18 @@ checkpoint:
activation_checkpoint:
mode: selective
selective_ac_option: op
+
+metric_logging:
+ wandb:
+ project: sft-training
+ group: sft_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+
+
+# profiling:
+# enable_profiling: false
+
+# metrics:
+# log_freq: 10
+# enable_tensorboard: true
+# save_tb_folder: "tb"
diff --git a/apps/sft/main.py b/apps/sft/main.py
index c806d037a..4f2a7be74 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -4,26 +4,34 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+"""To run:
+
+python -m apps.sft.main --config apps/sft/llama3_8b.yaml
+
+"""
+
+import asyncio
+import contextlib
+import logging
+import math
import os
import sys
-from dataclasses import asdict
-from functools import partial
from typing import Any
import torch
import torchtitan.experiments.forge.train_spec as forge_train_spec
-from forge.cli.config import parse
-from forge.data.collate import collate_packed
-from forge.data.datasets.packed import PackedDataset, TextPacker
+from forge.controller import ForgeActor
+from forge.data.collate import collate_padded
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
-from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX
-from forge.util import get_metric_logger
+from forge.data.utils import StopAfterOneEpoch
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce
+from forge.util.config import parse
+from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
from torch import nn
-
from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.lr_scheduler import LRSchedulersContainer
@@ -31,9 +39,8 @@
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
-from torchtitan.models.attention import init_attention_mask
-from tqdm import tqdm
+# from tqdm import tqdm
# stubs for now
Checkpointer = Any
@@ -42,8 +49,11 @@
Profiler = Any
Tokenizer = Any
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
-class ForgeSFTRecipe(ForgeEngine):
+class ForgeSFTRecipe(ForgeActor, ForgeEngine):
job_config: ForgeJobConfig
train_spec: forge_train_spec.ForgeTrainSpec
parallel_dims: ParallelDims
@@ -60,41 +70,106 @@ class ForgeSFTRecipe(ForgeEngine):
device: torch.device
step: int
- def __init__(self, job_config: ForgeJobConfig):
+ def __init__(self, config: DictConfig):
+ job_config = ForgeJobConfig().to_dict()
+ # Hack to deal with literal types from titan
+ job_config = OmegaConf.merge(job_config, config)
+
self.current_step = 0
self.num_training_steps = job_config.training.steps
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
+ self._rank = current_rank().rank
+ self._size = math.prod(current_size().values())
super().__init__(job_config)
- self.metric_logger = get_metric_logger(**job_config.metrics)
- def setup(self):
- self.train_dataloader = self.setup_data(
- self.job_config.dataset,
- batch_size=self.job_config.training.local_batch_size,
- )
+ async def setup_metric_logger(self):
+ """Initialization happens in the main process. Here we just retrieve it"""
+ mlogger = await get_or_create_metric_logger()
+ return mlogger
+
+ def record_batch_metrics(self, data_metrics: list):
+ """Since the dataloader creates new processes, we dont call `record_metric` in the dataset.
+ Instead, pop the metrics from the batch and record them here."""
+ for metric in data_metrics:
+ record_metric(metric.key, metric.value, metric.reduction)
+
+ @endpoint
+ async def setup(self):
+ # Validate that compile is only used with flex attention
+ if self.job_config.training.compile:
+ raise ValueError(
+ "training.compile=True is not currently supported. "
+ "Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
+ "Please set training.compile=false in your config."
+ )
- self.val_dataloader = self.setup_data(
- self.job_config.dataset_val,
- batch_size=self.job_config.validation.local_batch_size,
+ # all ranks should record loss, except when PP=True. Then, only the last stage should record loss.
+ self.rank_should_record_loss = True
+ if hasattr(self, "pp_has_last_stage") and not self.pp_has_last_stage:
+ self.rank_should_record_loss = False
+
+ # metric logger
+ self.mlogger = await self.setup_metric_logger()
+
+ # Load training datasets
+ logger.info("Setting training datasets")
+ train_datasets_config = self.job_config.training.datasets
+ self.train_dataloader = self.setup_data(train_datasets_config)
+
+ # Load eval datasets
+ eval_config = self.job_config["eval"]
+ self.val_dataloaders = {}
+ self.eval_every_n_steps = eval_config["eval_every_n_steps"]
+ max_eval_steps = eval_config["max_eval_steps"]
+ self.max_eval_steps = (
+ max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
+ )
+ self.validation_enabled = (
+ self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
)
+ if self.validation_enabled:
+ logger.info("Setting eval datasets")
+ self.eval_datasets_config = eval_config.datasets
- # self.train_dataloader = self.setup_data(
- # self.train_config.train_dataset_config,
- # self.train_config.train_dataloader_config,
- # self.train_config.packing_config,
- # )
- # self.val_dataloader = self.setup_data(
- # self.train_config.val_dataset_config,
- # self.train_config.val_dataloader_config,
- # self.train_config.packing_config,
- # )
+ for i, dataset_config in enumerate(self.eval_datasets_config):
+ ds_name = dataset_config.get("dataset_name", i)
+ # TODO: Support separate eval batch size from config (eval.local_batch_size)
+ dataloader = self.setup_data([dataset_config])
+ self.val_dataloaders[ds_name] = dataloader
+
+ # TODO: confirm that this is working properly
+ # Should also use load, not dcp_load
self.checkpointer.load(step=self.current_step)
+
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)
- def setup_data(self, dataset_config, batch_size):
- self.tokenizer = HuggingFaceModelTokenizer(
+ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
+ """Instantiates datasets and returns a StatefulDataLoader.
+
+ Args:
+ dataset_configs (list[dict]): List of dataset config dicts used as `sft_iterable_dataset(**dataset_configs[i])`.
+
+ Returns:
+ StatefulDataLoader
+
+ Raises:
+ ValueError: If multiple datasets provided (not yet supported)
+ """
+
+ # TODO felipemello: Currently only support single dataset
+ if len(dataset_configs) > 1:
+ raise ValueError(
+ f"Multiple training datasets not supported yet. "
+ f"Got {len(dataset_configs)} datasets. "
+ )
+
+ dataset_config = dataset_configs[0]
+
+ # TODO: Evaluate if tokenizers should be created once and shared for every dataset
+ # Load tokenizer
+ tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
self.job_config.model.hf_assets_path, "tokenizer.json"
),
@@ -104,39 +179,43 @@ def setup_data(self, dataset_config, batch_size):
generation_config_path=os.path.join(
self.job_config.model.hf_assets_path, "generation_config.json"
),
+ chat_template_path=(
+ path
+ if os.path.exists(
+ path := os.path.join(
+ self.job_config.model.hf_assets_path, "chat_template.jinja"
+ )
+ )
+ else None
+ ),
)
+ # Get DP mesh for data sharding
+ dp_mesh = None
+ if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+ dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+ # Pass config directly to dataset constructor
dataset = sft_iterable_dataset(
- model_transform=self.tokenizer,
+ model_transform=tokenizer,
message_transform=AlpacaToMessages(),
- path=dataset_config.path,
- split=dataset_config.split,
- )
- packer = TextPacker(padding_idx=0)
- dataset = PackedDataset(
- dataset=dataset,
- packer=packer,
- target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model
+ dp_mesh=dp_mesh,
+ **dataset_config,
)
+
dataloader = StatefulDataLoader(
dataset=dataset,
- batch_size=batch_size,
- collate_fn=partial(
- collate_packed, mask_fn=packer.create_block_mask, device=self.device
- ),
+ batch_size=self.job_config.training.local_batch_size,
+ collate_fn=collate_padded,
)
- # Ultimately we probably want something like this
- # packer = build_packing_strategy(packing_config)
- # dataset = build_dataset(dataset_config)
- # dataloader = build_dataloader(dataloader_config, dataset, packer)
return dataloader
def forward_backward(
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
- do_backward: bool = True,
+ skip_backward: bool = False,
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims
@@ -144,13 +223,6 @@ def forward_backward(
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
inputs = input_dict["tokens"]
-
- if getattr(self.model_args, "use_flex_attn", False):
- cp_mesh = (
- parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
- )
- init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)
-
optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
@@ -169,24 +241,23 @@ def forward_backward(
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
- if do_backward:
- pp_schedule_fn = self.pp_schedule.step
- else:
- pp_schedule_fn = self.pp_schedule.eval
if self.pp_has_first_stage:
- pp_schedule_fn(
- inputs, target=targets, losses=losses, input_batch=inputs
- )
+ self.pp_schedule.step(inputs, target=targets, losses=losses)
else:
- pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)
+ self.pp_schedule.step(target=targets, losses=losses)
# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
- torch.mean(torch.stack(losses)).to(self.device)
+ torch.sum(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
- else torch.tensor([-1.0], device=self.device)
+ else torch.tensor(-1.0, device=self.device)
)
+
+ # TODO: PP requires gradients enabled and cant deactive with no_grad
+ if skip_backward:
+ loss = loss.detach()
+
else:
# Non-PP forward / backward
with self.train_context(optional_context_parallel_ctx):
@@ -196,7 +267,9 @@ def forward_backward(
loss = self.loss_fn(pred, labels)
# need to free to before bwd to avoid peaking memory
del pred
- if do_backward:
+
+ # Only run backward if requested. Useful for eval.
+ if not skip_backward:
loss.backward()
return loss
@@ -210,103 +283,222 @@ def train_step(self, batch) -> None:
# ) as grad_acc:
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels)
- self.pbar.update(1)
- self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
- self.metric_logger.log("loss", loss.item(), self.current_step)
+ if self.rank_should_record_loss:
+ loss_val = loss.item()
+ record_metric("ForgeSFTRecipe/train_step/loss", loss_val, Reduce.MEAN)
+ logger.info(
+ f"step {self.current_step} / {self.num_training_steps} | Loss: {loss_val}"
+ )
+
+ # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
+ # self.pbar.update(1)
self.optimizers.step()
- self.optimizers.zero_grad()
self.lr_schedulers.step()
- def train(self) -> None:
+ async def evaluate(self) -> None:
+ """Run evaluation on multiple datasets, one at a time.
+
+ 1. Set models to eval mode
+ 2. For each eval dataset:
+ - Create fresh iterator (starts from epoch 0)
+ - Use StopAfterOneEpoch to iterate until epoch boundary. This utility
+ is necessary for infinite iterable dataset, since epoch boundaries are not known.
+ - Respect max_eval_steps cap if configured
+ - Record loss and step metrics (on dp rank only)
+ 3. Restore models to train mode
+ """
+
+ # Set models to eval mode
+ for model_part in self.model_parts:
+ model_part.eval()
+
+ # Get DP process group for epoch synchronization
+ dp_mesh = None
+ if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+ dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+ # For non-PP: disable gradients to save memory
+ # TODO: For PP, if disabling gradients, throws error
+ maybe_no_grad = (
+ contextlib.nullcontext()
+ if self.parallel_dims.pp_enabled
+ else torch.no_grad()
+ )
+
+ # Evaluate each dataset sequentially
+ all_dataset_losses = []
+ all_dataset_steps = []
+ for dataset_name, val_dataloader in self.val_dataloaders.items():
+ logger.info(f"=====Evaluating dataset: {dataset_name}=====")
+
+ # Evaluation loop for this dataset
+ total_loss = torch.tensor(0.0, device=self.device)
+ num_steps = 0
+
+ # NOTE: Assumes batch contains field "metrics" containing "num_epochs"
+ batch_iter = StopAfterOneEpoch(
+ iter=iter(val_dataloader), # Fresh iterator from epoch 0,
+ device=self.device,
+ dp_mesh=dp_mesh,
+ )
+
+ with maybe_no_grad:
+ for batch in batch_iter:
+ # if max_eval_steps>len(dataset), it will be stopped earlier by StopAfterOneEpoch.
+ if (
+ self.max_eval_steps is not None
+ and num_steps >= self.max_eval_steps
+ ):
+ logger.info(
+ f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
+ )
+ break
+
+ # Move tensors to device
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.to(self.device)
+
+ # Process batch
+ labels = batch.pop("labels")
+ loss = self.forward_backward(batch, labels, skip_backward=True)
+ total_loss += loss
+ num_steps += 1
+
+ # Log progress
+ if self.rank_should_record_loss:
+ loss_val = loss.item()
+ logger.info(
+ f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss_val:.4f}"
+ )
+
+ # log loss
+ avg_loss = (total_loss / max(num_steps, 1)).item()
+ all_dataset_losses.append(avg_loss)
+ all_dataset_steps.append(num_steps)
+ logger.info(
+ f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
+ )
+ if self.rank_should_record_loss:
+ record_metric(
+ f"evaluate/dataset_{dataset_name}_avg_loss",
+ avg_loss,
+ Reduce.MEAN,
+ )
+
+ # Record macro and micro average losses across datasets (only if multiple datasets)
+ if self.rank_should_record_loss and len(all_dataset_losses) > 1:
+ # Macro: same weight for all datasets
+ macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
+ record_metric("evaluate/macro_avg_loss", macro_avg_loss, Reduce.MEAN)
+
+ # Micro: weighted mean by dataset size
+ total_steps = sum(all_dataset_steps)
+ micro_avg_loss = (
+ sum(
+ loss * steps
+ for loss, steps in zip(all_dataset_losses, all_dataset_steps)
+ )
+ / total_steps
+ )
+ record_metric("evaluate/micro_avg_loss", micro_avg_loss, Reduce.MEAN)
+
+ logger.info(
+ f"Macro avg loss (unweighted): {macro_avg_loss:.4f}, "
+ f"Micro avg loss (weighted): {micro_avg_loss:.4f}"
+ )
+
+ # Restore train mode
+ for model_part in self.model_parts:
+ model_part.train()
+
+ logger.info("==Evaluation complete==")
+
+ @endpoint
+ async def train(self) -> None:
dataloader = iter(self.train_dataloader)
self.optimizers.zero_grad()
- self.pbar = tqdm(
- initial=0,
- total=self.num_training_steps,
- desc=f"{self.current_step}",
- )
+ # TODO: tqdm is broken in Monarch actors
+ # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)
while self.current_step < self.num_training_steps:
batch = next(dataloader)
+
+ # Pop and record metrics from batch before moving to device
+ self.record_batch_metrics(batch.pop("metrics", []))
+ record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN)
+
# Move tensors to the appropriate device
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to("cuda") # TODO: hardcoded for now
+
self.train_step(batch)
# self.profiler.step()
self.current_step += 1
+ # Run evaluation periodically if enabled
+ if (
+ self.validation_enabled
+ and self.current_step % self.eval_every_n_steps == 0
+ ):
+ await self.evaluate()
+
self.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)
- if (
- self.job_config.validation.freq > 0
- and self.job_config.validation.steps > 0
- and self.current_step % self.job_config.validation.freq == 0
- ):
- self.validate(self.job_config.validation.steps)
-
- def validate(self, max_steps: int) -> None:
- for m in self.model_parts:
- m.eval()
- total_val_loss = torch.tensor(0.0, device=self.device)
- total_val_tokens = torch.tensor(0.0, device=self.device)
- with torch.no_grad():
- val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
- for batch_idx, batch in enumerate(val_pbar):
- if batch_idx >= max_steps:
- break
- batch_to_device(batch, self.device)
- current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
- # Compute loss
- labels = batch.pop("labels")
- loss = self.forward_backward(batch, labels, do_backward=False)
- val_loss = loss * current_num_tokens
- total_val_loss += val_loss
- total_val_tokens += current_num_tokens
- # Update progress bar description with current average loss
- avg_loss_so_far = (
- (total_val_loss / total_val_tokens).item()
- if total_val_tokens > 0
- else float("inf")
- )
- val_pbar.set_description(
- f"Running validation Loss: {avg_loss_so_far:.4f}"
- )
- # Aggregate validation metrics across all ranks
- torch.distributed.all_reduce(total_val_loss)
- torch.distributed.all_reduce(total_val_tokens)
- avg_val_loss = (
- (total_val_loss / total_val_tokens).item()
- if total_val_tokens > 0
- else float("inf")
- )
- for m in self.model_parts:
- m.train()
- print(f"\nValidation loss: {avg_val_loss}")
+ # Flush metrics
+ if self._rank == 0:
+ await self.mlogger.flush.call_one(global_step=self.current_step)
+
+ # self.pbar.close()
+
+ if self.validation_enabled:
+ logger.info("Running final evaluation at end of training...")
+ await self.evaluate()
- def cleanup(self) -> None:
+ @endpoint
+ async def cleanup(self) -> None:
if self.checkpointer:
self.checkpointer.close()
- if self.metric_logger:
- self.metric_logger.close()
+ if getattr(self, "mlogger", None):
+ await self.mlogger.shutdown.call_one()
+
+ def __repr__(self) -> str:
+ return "Trainer"
+
+
+async def run(cfg: DictConfig) -> None:
+ logging.info("Spawning recipe...")
+ process_cfg = cfg.pop("processes")
+
+ # Initialize metric logger in main process
+ metric_logging_cfg = cfg.get("metric_logging", {})
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
+ await mlogger.init_backends.call_one(metric_logging_cfg)
+
+ recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
+
+ logging.info("Created recipe, running setup.")
+ await recipe.setup.call()
+
+ logging.info("Recipe has been setup. Training now.")
+ await recipe.train.call()
+
+ logging.info("Done training. Clean up")
+ await recipe.cleanup.call()
+
+ await recipe.mesh.stop()
+ logging.info("All done!")
@parse
def recipe_main(cfg: DictConfig) -> None:
- # TODO: this is a hack to get the defaults from ForgeJobConfig
- default_cfg = ForgeJobConfig()
- # Hack to deal with literal types from titan
- default_cfg = asdict(default_cfg)
- cfg = OmegaConf.merge(default_cfg, cfg)
- recipe = ForgeSFTRecipe(cfg)
- recipe.setup()
- recipe.train()
- recipe.cleanup()
+ asyncio.run(run(cfg))
if __name__ == "__main__":
diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft/qwen3_8b.yaml
similarity index 50%
rename from apps/sft_v2/llama3_8b.yaml
rename to apps/sft/qwen3_8b.yaml
index 86fd88ca5..a8d2244a9 100644
--- a/apps/sft_v2/llama3_8b.yaml
+++ b/apps/sft/qwen3_8b.yaml
@@ -1,11 +1,4 @@
-# >>> python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml
-
-# Config for supervised full finetuning using a Llama3.1 8B Instruct model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# export HF_HUB_DISABLE_XET=1
-# forge download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct
+# >>> python -m apps.sft.main --config apps/sft/qwen3_8b.yaml
# TODO: required by torchtitan
@@ -13,10 +6,12 @@
comm:
trace_buf_size: 0
+model_name: "Qwen/Qwen3-8B"
+
model:
- name: llama3
+ name: qwen3
flavor: 8B
- hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct
+ hf_assets_path: hf://${model_name}
processes:
procs: 8
@@ -31,12 +26,21 @@ lr_scheduler:
warmup_steps: 200
training:
- local_batch_size: 1
+ local_batch_size: 8
seq_len: 2048
max_norm: 1.0
steps: 1000
compile: false
- dataset: "c4"
+ datasets:
+ - path: "yahma/alpaca-cleaned"
+ split: "train[:95%]"
+
+eval:
+ eval_every_n_steps: 50 # null = disabled
+ max_eval_steps: null # null = run until epoch completes
+ datasets:
+ - path: "yahma/alpaca-cleaned"
+ split: "train[95%:]"
parallelism:
data_parallel_replicate_degree: 1
@@ -49,9 +53,9 @@ parallelism:
checkpoint:
enable: true
- folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
- initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
- initial_load_in_hf: true
+ folder: ./checkpoint # The folder to save checkpoints to.
+ initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+ initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
last_save_in_hf: true
interval: 500
async_mode: "disabled"
@@ -60,6 +64,12 @@ activation_checkpoint:
mode: selective
selective_ac_option: op
+metric_logging:
+ wandb:
+ project: sft-training
+ group: sft_exp_${oc.env:USER}
+ logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+
# profiling:
# enable_profiling: false
diff --git a/apps/sft_v2/main.py b/apps/sft_v2/main.py
deleted file mode 100644
index 61b27baa3..000000000
--- a/apps/sft_v2/main.py
+++ /dev/null
@@ -1,303 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""To run:
-
-python -m apps.sft_v2.main --config apps/sft_v2/llama3_8b.yaml
-
-"""
-
-import asyncio
-
-import logging
-import math
-import os
-import sys
-from functools import partial
-from typing import Any
-
-import torch
-
-import torchtitan.experiments.forge.train_spec as forge_train_spec
-from forge.cli.config import parse
-from forge.controller import ForgeActor
-from forge.data.collate import collate_packed
-from forge.data.datasets.packed import PackedDataset, TextPacker
-from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
-from forge.data.tokenizer import HuggingFaceModelTokenizer
-
-from monarch.actor import current_rank, current_size, endpoint
-from omegaconf import DictConfig, OmegaConf
-from torch import nn
-from torchdata.stateful_dataloader import StatefulDataLoader
-from torchtitan.components.loss import LossFunction
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.components.optimizer import OptimizersContainer
-from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.experiments.forge.engine import ForgeEngine
-from torchtitan.experiments.forge.job_config import ForgeJobConfig
-
-# from tqdm import tqdm
-
-# stubs for now
-Checkpointer = Any
-Dataloader = Any
-MetricLogger = Any
-Profiler = Any
-Tokenizer = Any
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class ForgeSFTRecipe(ForgeActor, ForgeEngine):
- job_config: ForgeJobConfig
- train_spec: forge_train_spec.ForgeTrainSpec
- parallel_dims: ParallelDims
- model: list[nn.Module]
- loss_fn: LossFunction
- optimizer: OptimizersContainer
- lr_scheduler: LRSchedulersContainer
- checkpointer: Checkpointer
- tokenizer: Tokenizer
- train_dataloader: Dataloader
- # val_dataloader: Dataloader
- metric_logger: MetricLogger
- profiler: Profiler
- device: torch.device
- step: int
-
- def __init__(self, config: DictConfig):
- job_config = ForgeJobConfig().to_dict()
- # Hack to deal with literal types from titan
- job_config = OmegaConf.merge(job_config, config)
-
- self.current_step = 0
- self.num_training_steps = job_config.training.steps
- self.metric_logger = None # TODO: fix this
- self.gradient_accumulation_steps = 1 # Example value, adjust as needed
- self._rank = current_rank().rank
- self._size = math.prod(current_size().values())
- self._init_dist()
- super().__init__(job_config)
-
- def _init_dist(self):
- """Initializes torch distributed.
-
- torchrun normally hands this, but we need to do it ourselves
- in monarch for now.
-
- We should consider putting this into ForgeActor, but having this
- be explicit for now.
-
- """
- env = {
- "RANK": str(self._rank),
- "LOCAL_RANK": str(self._rank),
- "LOCAL_WORLD_SIZE": str(self._size),
- "GROUP_RANK": str(self._size),
- "GROUP_WORLD_SIZE": str(self._size),
- "ROLE_RANK": str(self._rank),
- "ROLE_WORLD_SIZE": str(self._size),
- "ROLE_NAME": "rank",
- "WORLD_SIZE": str(self._size),
- "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
- }
- os.environ.update(env)
- logger.info("env: {}".format(env))
-
- @endpoint
- async def setup(self):
- self.train_dataloader = self.setup_data()
- # self.train_dataloader = self.setup_data(
- # self.train_config.train_dataset_config,
- # self.train_config.train_dataloader_config,
- # self.train_config.packing_config,
- # )
- # self.val_dataloader = self.setup_data(
- # self.train_config.val_dataset_config,
- # self.train_config.val_dataloader_config,
- # self.train_config.packing_config,
- # )
-
- # TODO: confirm that this is working properly
- # Should also use load, not dcp_load
- self.checkpointer.load(step=self.current_step)
- # self.profiler = self.setup_profiler(self.train_config.profiler_config)
- # self.logger = self.setup_logger(self.train_config.logger_config)
-
- def setup_data(self):
- print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
- tokenizer = HuggingFaceModelTokenizer(
- tokenizer_json_path=os.path.join(
- self.job_config.model.hf_assets_path, "tokenizer.json"
- ),
- tokenizer_config_json_path=os.path.join(
- self.job_config.model.hf_assets_path, "tokenizer_config.json"
- ),
- generation_config_path=os.path.join(
- self.job_config.model.hf_assets_path, "generation_config.json"
- ),
- )
-
- dataset = sft_iterable_dataset(
- model_transform=tokenizer,
- message_transform=AlpacaToMessages(),
- path="yahma/alpaca-cleaned",
- split="train",
- )
- packer = TextPacker(padding_idx=0)
- dataset = PackedDataset(
- dataset=dataset,
- packer=packer,
- target_tokens_per_pack=self.job_config.training.seq_len, # TODO: get this from model
- )
- dataloader = StatefulDataLoader(
- dataset=dataset,
- batch_size=self.job_config.training.local_batch_size,
- collate_fn=partial(
- collate_packed, mask_fn=packer.create_block_mask, device=self.device
- ),
- )
-
- # Ultimately we probably want something like this
- # packer = build_packing_strategy(packing_config)
- # dataset = build_dataset(dataset_config)
- # dataloader = build_dataloader(dataloader_config, dataset, packer)
- return dataloader
-
- def forward_backward(
- self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
- ) -> torch.Tensor:
- model_parts = self.model_parts
- parallel_dims = self.parallel_dims
-
- # apply context parallelism if cp is enabled
- # ensure CP handles the separate freqs_cis buffer for each pp stage
- inputs = input_dict["tokens"]
- optional_context_parallel_ctx = (
- dist_utils.create_context_parallel_ctx(
- cp_mesh=parallel_dims.world_mesh["cp"],
- cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
- cp_seq_dims=[1, 1] + [0 for _ in model_parts],
- cp_no_restore_buffers={inputs, labels},
- cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
- )
- if parallel_dims.cp_enabled
- else None
- )
-
- if parallel_dims.pp_enabled:
- # Pipeline Parallel forward / backward inside step() call
- with self.train_context(optional_context_parallel_ctx):
- targets, losses = (
- (labels, []) if self.pp_has_last_stage else (None, None)
- )
- if self.pp_has_first_stage:
- self.pp_schedule.step(
- inputs, target=targets, losses=losses, input_batch=inputs
- )
- else:
- self.pp_schedule.step(
- target=targets, losses=losses, input_batch=inputs
- )
-
- # accumulate losses across pipeline microbatches
- # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
- loss = (
- torch.mean(torch.stack(losses)).to(self.device)
- if self.pp_has_last_stage
- else torch.tensor([-1.0], device=self.device)
- )
- else:
- # Non-PP forward / backward
- with self.train_context(optional_context_parallel_ctx):
- assert len(model_parts) == 1
- with self.maybe_enable_amp:
- pred = model_parts[0](inputs)
- loss = self.loss_fn(pred, labels)
- # need to free to before bwd to avoid peaking memory
- del pred
- loss.backward()
-
- return loss
-
- def train_step(self, batch) -> None:
- # TODO
- # with GradientAccumulation(
- # self.gradient_accumulation_steps,
- # self.model,
- # self.data_parallel_size,
- # ) as grad_acc:
- labels = batch.pop("labels")
- loss = self.forward_backward(batch, labels)
-
- logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
- # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
- # self.pbar.update(1)
- self.optimizers.step()
- self.lr_schedulers.step()
-
- @endpoint
- async def train(self) -> None:
- dataloader = iter(self.train_dataloader)
- self.optimizers.zero_grad()
-
- # TODO: tqdm is broken in Monarch actors
- # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)
-
- while self.current_step < self.num_training_steps:
- batch = next(dataloader)
- # Move tensors to the appropriate device
- for k, v in batch.items():
- if isinstance(v, torch.Tensor):
- batch[k] = v.to("cuda") # TODO: hardcoded for now
- self.train_step(batch)
- # self.profiler.step()
- self.current_step += 1
-
- self.checkpointer.save(
- curr_step=self.current_step,
- last_step=self.current_step == self.num_training_steps,
- )
-
- # self.pbar.close()
-
- @endpoint
- async def cleanup(self) -> None:
- if self.checkpointer:
- self.checkpointer.close()
- if self.metric_logger:
- self.metric_logger.close()
-
- def __repr__(self) -> str:
- return "Trainer"
-
-
-async def run(cfg: DictConfig) -> None:
- logging.info("Spawing recipe...")
- process_cfg = cfg.pop("processes")
- recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
-
- logging.info("Created recipe, running setup.")
- await recipe.setup.call()
-
- logging.info("Recipe has been setup. Training now.")
- await recipe.train.call()
-
- logging.info("Done training. Clean up")
- await recipe.cleanup.call()
- await recipe.mesh.stop()
- logging.info("All done!")
-
-
-@parse
-def recipe_main(cfg: DictConfig) -> None:
- asyncio.run(run(cfg))
-
-
-if __name__ == "__main__":
- sys.exit(recipe_main())
diff --git a/assets/ci/monarch_no_torch-0.1.0.dev20250826-py3-none-any.whl b/assets/ci/monarch_no_torch-0.1.0.dev20250826-py3-none-any.whl
deleted file mode 100644
index 4d3eaeb36..000000000
Binary files a/assets/ci/monarch_no_torch-0.1.0.dev20250826-py3-none-any.whl and /dev/null differ
diff --git a/assets/versions.sh b/assets/versions.sh
new file mode 100644
index 000000000..333fb9116
--- /dev/null
+++ b/assets/versions.sh
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Version Configuration for Forge Wheel Building
+# This file contains all pinned versions and commits for dependencies
+
+# Stable versions of upstream libraries for OSS repo
+PYTORCH_VERSION="2.9.0"
+VLLM_VERSION="v0.10.0"
+MONARCH_VERSION="0.1.2"
+TORCHTITAN_VERSION="0.2.0"
+TORCHSTORE_VERSION="0.1.2"
+
+# Torchtitan commit hash for launching on MAST
+TORCHTITAN_COMMIT_MAST="d0e25450bcac2332359b13fbda430dc701f073d4"
diff --git a/assets/wheels/monarch-0.0.1-cp310-cp310-linux_x86_64.whl b/assets/wheels/monarch-0.0.1-cp310-cp310-linux_x86_64.whl
deleted file mode 100644
index 146e04a27..000000000
Binary files a/assets/wheels/monarch-0.0.1-cp310-cp310-linux_x86_64.whl and /dev/null differ
diff --git a/assets/wheels/torchstore-0.1.0-py3-none-any.whl b/assets/wheels/torchstore-0.1.0-py3-none-any.whl
deleted file mode 100644
index dd360e129..000000000
Binary files a/assets/wheels/torchstore-0.1.0-py3-none-any.whl and /dev/null differ
diff --git a/assets/wheels/torchtitan-0.1.0-py3-none-any.whl b/assets/wheels/torchtitan-0.1.0-py3-none-any.whl
deleted file mode 100644
index cc61d6db1..000000000
Binary files a/assets/wheels/torchtitan-0.1.0-py3-none-any.whl and /dev/null differ
diff --git a/demo.ipynb b/demo.ipynb
deleted file mode 100644
index 459fa0654..000000000
--- a/demo.ipynb
+++ /dev/null
@@ -1,677 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "ed6165f4-0038-40d3-bf65-419fcf61af24",
- "metadata": {},
- "source": [
- "# Intro to Forge\n",
- "\n",
- "Forge is a PyTorch-native framework designed for rapid experimentation and large-scale training of Reinforcement Learning (RL) algorithms with Large Language Models (LLMs). It's designed to:\n",
- "- Express RL algorithms as naturally as psuedocode, while scaling seamlessly across clusters\n",
- "- Support varying degrees of asynchrony - from fully synchronous/on-policy, to fully asynchronous/off-policy training\n",
- "- Separate infrastructural concerns from algorithmic implementation\n",
- "- Bias towards composable, reusable components that can be mixed and matched for different RL approaches\n",
- "\n",
- "Forge is built on top of proven components:\n",
- "- **[Monarch](https://github.com/meta-pytorch/monarch)** - PyTorch-native single-controller framework\n",
- "- **[torchtitan](https://github.com/pytorch/torchtitan)** - PyTorch-native large-scale LLM training platform\n",
- "- **[vLLM](https://github.com/vllm-project/vllm)** - A high-throughput, memory efficient inference and serving engine for LLMs\n",
- "\n",
- "Our mission is to accelerate innovation in reinforcement learning by empowering researchers and developers to explore new RL algorithms and infrastructure techniques. Whether you're designing novel training methods or optimizing distributed systems, Forge provides a foundation to build upon."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "24de9912-ed10-4729-9616-2f85bbf64e43",
- "metadata": {},
- "source": [
- "## Brief Intro to Monarch\n",
- "Before diving into Forge, we need to first establish the foundation. Forge is built on top of Monarch, PyTorch's native single-controller framework for distributed execution.\n",
- "\n",
- "Forge builds many of its abstractions on top of Monarch, so it's worth introducing a few of its key concepts first. The following sections borrow from Monarch's Getting Started Guide (not public yet).\n",
- "\n",
- "### Defining an Actor\n",
- "At its core, Monarch uses [actors](https://en.wikipedia.org/wiki/Actor_model) as a way to create multi-machine programs. Actors are Python objects that expose a number of endpoint functions. These functions can be called by other actors in the system and their responses gathered asynchronously."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "334ff5d6-bf9f-4b53-a04e-b488083a8101",
- "metadata": {},
- "outputs": [],
- "source": [
- "import asyncio\n",
- "from monarch.actor import Actor, endpoint, this_proc\n",
- "\n",
- "class Counter(Actor):\n",
- " def __init__(self, initial_value: int):\n",
- " self.value = initial_value\n",
- "\n",
- " @endpoint\n",
- " def increment(self) -> None:\n",
- " self.value += 1\n",
- "\n",
- " @endpoint\n",
- " def get_value(self) -> int:\n",
- " return self.value\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8b2815f7-ad6f-4928-b262-033c9b5cb847",
- "metadata": {},
- "source": [
- "The decorator `@endpoint` specifies functions of the Actor that can be called remotely from other actors.\n",
- "\n",
- "### Spawning An Actor In The Local Process\n",
- "\n",
- "We spawn actors in the current running process like so:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a9e22453-d877-4334-8d80-b3bc1de85455",
- "metadata": {},
- "outputs": [],
- "source": [
- "counter: Counter = this_proc().spawn(\"counter\", Counter, initial_value=0)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b8a9aa84-2ef4-4961-9173-4fee73c065c5",
- "metadata": {},
- "source": [
- "`this_proc()` is a handle to a process, giving us direct control over where an actor runs. Monarch is very literal about where things are run, so that code can be written in the most efficient way. \n",
- "\n",
- "### Sending A Simple Message\n",
- "Once an actor is spawned, we can send messages to the actor:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7d20ada2-9084-4439-bd9b-e95881cf7009",
- "metadata": {},
- "outputs": [],
- "source": [
- "from monarch.actor import Future\n",
- "\n",
- "fut: Future[int] = counter.get_value.call_one()\n",
- "\n",
- "value = await fut\n",
- "\n",
- "print(f\"Counter value: {value}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6bd15594-380a-44a1-977c-c536b6ae3a9c",
- "metadata": {},
- "source": [
- "Here we invoke the `get_value` message, returning 0, the current value of the Counter. `call_one` here is referred to as an \"adverb\" because it modified how results of the endpoint are handled. `call_one` just invokes a single actor and gets its value.\n",
- "\n",
- "Notice that the return value is a `Future[int]` - the message is sent asynchronously, letting the sender do other things before it needs the reply. We can `await` on the result.\n",
- "\n",
- "### Multiple Actors at Once\n",
- "Monarch scales to thousands of machines because of its ability to broadcast a single message to many actors at once, rather than send many point-to-point messages.\n",
- "\n",
- "Monarch expresses broadcasted communication by organizing actors into a `Mesh` - a multi-dimensional container with named dimensions. An example cluster may have dimensions `{\"hosts\": 32, \"gpus\": 8}`. To create a mesh of actors, we'll create a mesh of processes and spawn an actor on them:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4349a0c4-161c-4d68-9d97-202b650e344c",
- "metadata": {},
- "outputs": [],
- "source": [
- "from monarch.actor import ProcMesh, this_host\n",
- "\n",
- "procs: ProcMesh = this_host().spawn_procs(per_host={\"gpus\": 8})\n",
- "counters: Counter = procs.spawn(\"counters\", Counter, 0)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "202a8126-2e45-4074-a2ba-87e12d2f06dc",
- "metadata": {},
- "source": [
- "### Broadcasting Messages\n",
- "Now messages can be sent to all actors in the mesh:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e1395ba9-c47a-4902-bea2-320bd1144fd2",
- "metadata": {},
- "outputs": [],
- "source": [
- "await counters.increment.call()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "35e2f5da-1840-49f5-b187-47be6d4b185d",
- "metadata": {},
- "source": [
- "Note that here, we use the `call()` adverb. You will see other adverbs in Monarch code as well:\n",
- "- `call_one()` - invoke a single actor and get its value (what we saw before)\n",
- "- `choose()` - randomly invoke a single actor and gets its value from within a mesh of actors\n",
- "- `call()` - invoke all actors in an actor mesh, and return its values as a `ValueMesh` \n",
- "- `broadcast()` - fire-and-forget all actors in an actor mesh\n",
- "- `stream()` - invoke all actors and return its values as an iterator\n",
- "\n",
- "There's much more to cover with Monarch, but these foundations provide the building blocks needed to understand how Forge creates its RL-specific services on top of this distributed actor system."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "34fbc2ba-620c-4170-8ad6-97e78ac3f0b0",
- "metadata": {},
- "outputs": [],
- "source": [
- "await procs.stop()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "924b1514-b10c-4a4e-bc11-222f3f2a2933",
- "metadata": {},
- "source": [
- "## Forge Services\n",
- "Forge introduces *Services* - a higher-level abstraction built on top of Monarch actors. Services handle all the operational complexity of managing distributed ActorMeshes: spawning actors across nodes, fault tolerance, load balancing, and intelligent routing.\n",
- "\n",
- "### Creating a Forge Service\n",
- "Creating a Forge service requires minimal changes to actors like we've created above. You replace your Actor base with a ForgeActor, and change how you spawn the actor:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "248530c1-3cbc-44b1-aa80-aab335280870",
- "metadata": {},
- "outputs": [],
- "source": [
- "from forge.controller import ForgeActor\n",
- "from forge.controller.service import ServiceConfig, spawn_service, shutdown_service\n",
- "from monarch.actor import endpoint\n",
- "\n",
- "\n",
- "class ForgeCounter(ForgeActor):\n",
- " def __init__(self, initial_value: int):\n",
- " self.value = initial_value\n",
- "\n",
- " @endpoint\n",
- " def increment(self) -> int:\n",
- " self.value += 1\n",
- " return self.value\n",
- "\n",
- " @endpoint\n",
- " def get_value(self) -> int:\n",
- " return self.value\n",
- "\n",
- " @endpoint\n",
- " async def reset(self):\n",
- " self.value = 0\n",
- "\n",
- " @endpoint\n",
- " def fail(self):\n",
- " raise RuntimeError(\"I was asked to fail\")\n",
- "\n",
- "\n",
- "counter_service = await spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, num_replicas=4),\n",
- " ForgeCounter,\n",
- " initial_value=0)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8f905101-9a69-4532-8e88-711c83ed1570",
- "metadata": {},
- "source": [
- "Here, we've created a simple \"Counter service\" with 4 replicas, each running on 1 process.\n",
- "\n",
- "### Service Adverbs: Operating at the Replica Level\n",
- "Services introduce new adverbs that operate at the replica level, not individual actors. Since replicas can be spawned across multiple processes, each replica is essentially an ActorMesh in Monarch terms:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "07a02047-f29a-47bf-a077-6bd5c2976cf1",
- "metadata": {},
- "outputs": [],
- "source": [
- "# choose() - routes to one replica (load balanced, and which may contain multiple actors)\n",
- "await counter_service.increment.choose()\n",
- "\n",
- "# call() - runs on ALL replicas\n",
- "results = await counter_service.increment.call()\n",
- "\n",
- "print(results)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0a08b5a5-0180-4e0d-80b4-cc0620c510b4",
- "metadata": {},
- "source": [
- "Key distinction:\n",
- "- Monarch's `choose()` picks a single actor from an `ActorMesh`\n",
- "- Forge's `choose()` picks a single replica (which could be an entire `ActorMesh` of actors)\n",
- "\n",
- "This abstraction lets you think in terms of logical compute units (replicas) rather than individual processes."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4ac9c7dc-42b5-4e10-a0ae-34916fc40360",
- "metadata": {},
- "source": [
- "### Load Balancing in Action\n",
- "Services handle load balancing:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "11181a5a-5692-4163-a96a-151d9b075454",
- "metadata": {},
- "outputs": [],
- "source": [
- "await counter_service.reset.call()\n",
- "print(\"Increment on different replicas:\")\n",
- "for i in range(8):\n",
- " result = await counter_service.increment.choose()\n",
- " print(f\"Call {i}: Counter value = {result}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "64c9334f-4db7-4cde-8ef2-d8e66a26d21b",
- "metadata": {},
- "source": [
- "Each replica maintains its own state, and requests are distributed evenly.\n",
- "\n",
- "### Sticky Session for Stateful Operations\n",
- "When you need to interact with the same replica consistently:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "453ec452-8a71-44ca-8afe-4393270356a1",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Use sticky sessions to stay on one replica\n",
- "async with counter_service.session():\n",
- " await counter_service.reset.choose()\n",
- " print(await counter_service.increment.choose())\n",
- " print(await counter_service.increment.choose())\n",
- " print(await counter_service.increment.choose())\n",
- " \n",
- " final_value = await counter_service.get_value.choose()\n",
- " print(f\"Final value on this replica: {final_value}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b44eded4-d987-4620-9cf4-bbe3205102b1",
- "metadata": {},
- "source": [
- "Sticky sessions can be crucial for efficiency, i.e. whenever you need to maintain KV cache state across multiple turns.\n",
- "\n",
- "### Automatic Fault Tolerance\n",
- "Services automatically handle failures:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "48173946-c50f-4649-9064-ffaa4e116a7d",
- "metadata": {},
- "outputs": [],
- "source": [
- "# This will fail on one replica\n",
- "try:\n",
- " await counter_service.fail.choose()\n",
- "except ValueError:\n",
- " print(\"Replica failed, but service continues...\")\n",
- "\n",
- "# Subsequent calls automatically route around the failed replica\n",
- "result = await counter_service.increment.choose()\n",
- "print(f\"Still working: {result}\")\n",
- "\n",
- "# The failed replica will be automatically recovered"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "18e296b5-f1d5-4c3e-9f58-9842b78b1cc7",
- "metadata": {},
- "source": [
- "Behind the scenes: Forge marks unhealthy replicas, routes traffic away from them, and spawns replacements automatically.\n",
- "\n",
- "### Why This Matters for RL\n",
- "These service abstractions solve critical RL infrastructure challenges:\n",
- "\n",
- "1. Load balancing: Distribute rollouts across policy replicas efficiently\n",
- "2. Sticky sessions: Maintain state between rollouts and their associated replicas, i.e. KV cache consistency\n",
- "3. Fault tolerance: Keep training running even when individual nodes fail\n",
- "4. Operational simplicity: No infrastructure code in your RL algorithms\n",
- "\n",
- "### Performance: Control Plane vs Data Plane\n",
- "One important area we haven't covered yet is how Forge separates the **control plane** (service coordination) from the **data plane** (tensor transfers). You might reasonably wonder about performance implications if all data flows through TCP in a service-based architecture.\n",
- "\n",
- "We're actively developing **TorchStore** - our solution for high-performance tensor storage and retrieval over high-bandwidth interconnects like RDMA. This separation ensures that while Forge services handle coordination and routing, heavy tensor operations bypass the service layer entirely.\n",
- "\n",
- "*TorchStore will be covered in detail before our official release.*\n",
- "\n",
- "\n",
- "Next, we'll see how these building blocks enable elegant RL algorithm expression."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "44344bab-edcd-4851-98e4-59d8f9a4f3d8",
- "metadata": {},
- "outputs": [],
- "source": [
- "await shutdown_service(counter_service)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e3a941c8-a744-4666-a1cb-67849b58e80a",
- "metadata": {},
- "source": [
- "## Forge-Native Services\n",
- "Now let's see the power of this abstraction in action. Forge provides service implementations of common RL components that constitute typical training workloads, for instance:\n",
- "- Policy: Responsible for generating trajectories and responses\n",
- "- Trainer: Responsible for updating policy weights\n",
- "- Reference Model: Responsible for computing reference logprobs to prevent policy drift\n",
- "- Reward: Responsible for evaluating trajectory quality\n",
- "- Dataset: Responsible for serving prompts and target answers\n",
- "- Advantage: Responsible for computing advantages from trajectories\n",
- "\n",
- "\n",
- "### Building a Synchronous RL Workflow\n",
- "Let's demonstrate by building a simple on-policy RL workflow. We'll start by spinning up multiple services using a small Qwen model:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c622ed81-8fbc-4bd6-95f8-11a5df682711",
- "metadata": {},
- "outputs": [],
- "source": [
- "from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig\n",
- "from forge.actors.replay_buffer import ReplayBuffer\n",
- "from forge.controller.actor import ForgeActor\n",
- "from forge.controller.service import ServiceConfig, shutdown_service, spawn_service\n",
- "from forge.data.rewards import MathReward, ThinkingReward\n",
- "from apps.grpo.main import Trainer, RewardActor, ComputeAdvantages, RefModel, DatasetActor, Group, Episode\n",
- "\n",
- "\n",
- "model = \"Qwen/Qwen3-1.7B\"\n",
- "group_size = 1\n",
- "\n",
- "(\n",
- " dataloader,\n",
- " policy,\n",
- " trainer,\n",
- " replay_buffer,\n",
- " compute_advantages,\n",
- " ref_model,\n",
- " reward_actor,\n",
- ") = await asyncio.gather(\n",
- " spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, num_replicas=1),\n",
- " DatasetActor,\n",
- " path=\"openai/gsm8k\",\n",
- " config_name=\"main\",\n",
- " split=\"train\",\n",
- " streaming=True,\n",
- " ),\n",
- " spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),\n",
- " Policy,\n",
- " config=PolicyConfig(\n",
- " worker_params=WorkerConfig(model=model),\n",
- " sampling_params=SamplingOverrides(\n",
- " num_samples=group_size, max_tokens=16\n",
- " ),\n",
- " ),\n",
- " ),\n",
- " spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),\n",
- " Trainer,\n",
- " learning_rate=1e-5,\n",
- " beta=0.1,\n",
- " model_name=model,\n",
- " ),\n",
- " spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, num_replicas=1),\n",
- " ReplayBuffer,\n",
- " batch_size=2,\n",
- " max_policy_age=1,\n",
- " ),\n",
- " spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, num_replicas=1),\n",
- " ComputeAdvantages,\n",
- " gamma=0.99,\n",
- " lambda_=0.95,\n",
- " ),\n",
- " spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),\n",
- " RefModel,\n",
- " model_name=model,\n",
- " ),\n",
- " spawn_service(\n",
- " ServiceConfig(procs_per_replica=1, num_replicas=1),\n",
- " RewardActor,\n",
- " reward_functions=[MathReward(), ThinkingReward()],\n",
- " ))\n",
- " "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0f6e86de-0ba8-4a79-84ec-78bf99a96616",
- "metadata": {},
- "source": [
- "What's happening here:\n",
- "- Each service is spawned independently with its own configuration\n",
- "- GPU services like the `policy`, `trainer`, and `ref_model` get GPU allocation\n",
- "- All services run concurrently and can be scaled independently\n",
- "- The same model is used for policy and reference, but they're separate services\n",
- "\n",
- "Notice what we're not doing:\n",
- "- Managing CUDA placement across services\n",
- "- Coordinating distributed training setup\n",
- "- Handling inter-service communication protocols\n",
- "- Writing fault tolerance and retry logic\n",
- "\n",
- "All of this is handled by our Service abstraction.\n",
- "\n",
- "Let's check that the services indeed work as expected:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8334fdf0-0afa-48c2-bb42-faa46e8be36a",
- "metadata": {},
- "outputs": [],
- "source": [
- "prompt = \"What is 3 + 5?\"\n",
- "responses = await policy.generate.choose(prompt=prompt)\n",
- "\n",
- "print(responses)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "22dc3db0-2402-4344-ae2a-658ea46d9792",
- "metadata": {},
- "source": [
- "The response quality isn't great (it's only a 1.7B model!), but the service infrastructure is working perfectly.\n",
- "\n",
- "## Building the RL Training Loop\n",
- "### The Role of RL in Post-Training\n",
- "One way to interpret the role of RL in post-training is to align a base pre-trained model towards hard-to-define targets. The goal is \"sampling\" the right data that we think will best align the model.\n",
- "\n",
- "This is the role of \"rollouts\" - creating the dataset used to update our policy. Rather than training on a static dataset, RL dynamically generates training data by having the current policy interact with the environment.\n",
- "\n",
- "Let's build a step-by-step synchronous training loop to see how these services work together. The basic RL cycle is:\n",
- "\n",
- "1. Collect Experience: Get a prompt, generate a response, evaluate the reward\n",
- "2. Compute Rewards: Calculate how much better/worse each action was than expected\n",
- "3. Store Experience: Add the episode to our replay buffer\n",
- "4. Sample & Train: Get a batch of experiences and update the policy\n",
- "5. Repeat: Continue this cycle to improve the policy\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "298547e7-d185-48d1-a4b7-d122f760b707",
- "metadata": {},
- "outputs": [],
- "source": [
- "from apps.grpo.main import Episode, Group\n",
- "\n",
- "\n",
- "async def simple_rl_step():\n",
- " \"\"\"Execute one complete RL training step\"\"\"\n",
- " \n",
- " # ===== Generate a rollout =====\n",
- " sample = await dataloader.__next__.choose()\n",
- " prompt, target = sample[\"question\"], sample[\"answer\"]\n",
- " \n",
- " print(f\"Prompt: {prompt}\")\n",
- " print(f\"Target: {target}\")\n",
- " \n",
- " actions = await policy.generate.choose(prompt=prompt)\n",
- " print(f\"Policy response: {actions[0].text}\")\n",
- " \n",
- " ref_logprobs = await ref_model.forward.choose(actions[0].token_ids) \n",
- " reward = await reward_actor.evaluate_response.choose(\n",
- " prompt=prompt, \n",
- " response=actions[0].text, \n",
- " target=target\n",
- " )\n",
- " print(f\"Reward: {reward}\")\n",
- " \n",
- " episode = Episode(\n",
- " episode_id=0,\n",
- " prompt=prompt,\n",
- " target=target, \n",
- " policy_version=0,\n",
- " )\n",
- " \n",
- " episode.add_group(Group(\n",
- " response=actions[0].text,\n",
- " ref_logprobs=ref_logprobs,\n",
- " reward=reward,\n",
- " ))\n",
- " \n",
- " advantages = await compute_advantages.__call__.choose(episode.groups)\n",
- " episode.groups[0].advantage = advantages[0]\n",
- " print(f\"Advantage: {advantages[0]}\") \n",
- " await replay_buffer.add.choose(episode)\n",
- " print(\"Episode stored in replay buffer\")\n",
- " \n",
- " # ===== Train on the batch ===== \n",
- " batch = await replay_buffer.sample.choose(curr_policy_version=0)\n",
- " if batch is not None:\n",
- " print(\"Training on batch...\")\n",
- " training_result = await trainer.train_step.choose(batch)\n",
- " loss = training_result.get(\"loss\", 0.0)\n",
- " print(f\"Training loss: {loss}\")\n",
- " return loss\n",
- " else:\n",
- " print(\"Not enough data in buffer yet\")\n",
- " return None\n",
- "\n",
- "\n",
- "for step in range(10):\n",
- " print(f\"\\n--- RL Step {step + 1} ---\")\n",
- " loss = await simple_rl_step()\n",
- " if loss:\n",
- " print(f\"Step {step + 1} complete, loss: {loss:.4f}\")\n",
- " else:\n",
- " print(f\"Step {step + 1} complete, building buffer...\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "d03fb780-54ef-445e-85ef-5708ceb8e5be",
- "metadata": {},
- "source": [
- "**Note**: The model responses aren't great (1.7B parameters + 16 token limit = not exactly o1!), but notice how clean the RL algorithm code is. The power of these abstractions is that you can focus on the algorithm logic while all the distributed coordination happens automatically behind the scenes.\n",
- "\n",
- "TODO - conclude this with trainer->inference weight sync and demonstrate how the response changes\n",
- "\n",
- "## Next Steps\n",
- "This simple example demonstrates the core concepts, but for a production-ready implementation, check out our full GRPO (Group Relative Policy Optimization) example at apps/grpo/main.py. It includes the complete async training loops, proper logging, model weight synchronization, and all the optimizations needed for large-scale RL training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ff8b2f2e-98f6-41cd-865e-14cfa23c510f",
- "metadata": {},
- "outputs": [],
- "source": [
- "await asyncio.gather(\n",
- " shutdown_service(policy),\n",
- " shutdown_service(trainer),\n",
- " shutdown_service(replay_buffer),\n",
- " shutdown_service(dataloader),\n",
- " shutdown_service(compute_advantages),\n",
- " shutdown_service(ref_model),\n",
- " shutdown_service(reward_actor))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0be0a295-5d84-497d-a1b1-0bef53d9c753",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.18"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index 87c6f43da..000000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-sphinx==7.2.6
--e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
-docutils>=0.18.1,<0.21
-sphinx-design==0.6.1
-sphinxcontrib-mermaid==1.0.0
-myst-parser #==0.18.1 # if want to contribute in markdown
-sphinx-sitemap==2.7.1
diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css
new file mode 100644
index 000000000..5f2d0897b
--- /dev/null
+++ b/docs/source/_static/custom.css
@@ -0,0 +1,157 @@
+/* Center all Mermaid diagrams */
+.mermaid {
+ display: block;
+ margin-left: auto;
+ margin-right: auto;
+ text-align: center;
+}
+
+/* Center the pre element that contains mermaid diagrams */
+pre.mermaid {
+ display: flex;
+ justify-content: center;
+}
+
+/* Adjust Mermaid line colors based on theme */
+/* Light mode - darker lines for visibility on white background */
+html[data-theme="light"] .mermaid .edgePath .path,
+html[data-theme="light"] .mermaid .flowchart-link {
+ stroke: #555 !important;
+ stroke-width: 2px !important;
+}
+
+/* Light mode - darker arrow tips */
+html[data-theme="light"] .mermaid .arrowheadPath,
+html[data-theme="light"] .mermaid marker path {
+ fill: #555 !important;
+ stroke: #555 !important;
+}
+
+html[data-theme="dark"] .mermaid .arrowheadPath,
+html[data-theme="dark"] .mermaid marker path {
+ fill: #aaa !important;
+ stroke: #aaa !important;
+}
+
+/* Dark mode - lighter lines for visibility on dark background */
+html[data-theme="dark"] .mermaid .edgePath .path,
+html[data-theme="dark"] .mermaid .flowchart-link {
+ stroke: #aaa !important;
+ stroke-width: 2px !important;
+}
+
+/* Dark mode - lighter arrow tips */
+html[data-theme="dark"] .mermaid .arrowheadPath,
+html[data-theme="dark"] .mermaid marker path {
+ fill: #aaa !important;
+ stroke: #aaa !important;
+}
+
+/* Adjust edge labels background based on theme */
+html[data-theme="light"] .mermaid .edgeLabel {
+ background-color: #fff !important;
+}
+
+html[data-theme="dark"] .mermaid .edgeLabel {
+ background-color: #1a1a1a !important;
+ color: #fff !important;
+}
+
+/* Custom CSS for collapsible parameter lists */
+
+/* Hide parameters in signatures */
+.sig-param-hidden {
+ display: none !important;
+}
+
+/* Inline toggle button for signatures */
+.params-toggle-btn-inline {
+ display: inline;
+ padding: 0.2rem 0.5rem;
+ margin: 0 0.25rem;
+ background-color: var(--pst-color-background);
+ border: 1px solid var(--pst-color-border);
+ border-radius: 3px;
+ cursor: pointer;
+ font-size: 0.85em;
+ font-family: var(--pst-font-family-base);
+ color: var(--pst-color-primary);
+ transition: all 0.2s ease;
+ vertical-align: middle;
+}
+
+.params-toggle-btn-inline:hover {
+ background-color: var(--pst-color-background);
+ border-color: var(--pst-color-border);
+}
+
+.params-toggle-btn-inline:focus {
+ outline: none;
+}
+
+.toggle-icon {
+ display: inline-block;
+ font-size: 0.8em;
+ transition: transform 0.2s ease;
+}
+
+/* Wrapper for the button */
+.sig-params-wrapper {
+ display: inline;
+}
+
+/* Old styles for field-list collapsing (kept for backward compatibility) */
+.collapsible-params {
+ margin: 1rem 0;
+}
+
+.params-toggle-btn {
+ display: inline-block;
+ padding: 0.5rem 1rem;
+ margin-bottom: 0.5rem;
+ background-color: var(--pst-color-background);
+ border: 1px solid var(--pst-color-border);
+ border-radius: 4px;
+ cursor: pointer;
+ font-size: 0.9rem;
+ color: var(--pst-color-primary);
+ transition: all 0.3s ease;
+}
+
+.params-toggle-btn:hover {
+ background-color: var(--pst-color-background);
+ border-color: var(--pst-color-border);
+}
+
+.params-content {
+ max-height: 10000px;
+ overflow: hidden;
+ transition: max-height 0.5s ease, opacity 0.3s ease;
+ opacity: 1;
+}
+
+.params-content.collapsed {
+ max-height: 0;
+ opacity: 0;
+}
+
+/* Ensure the collapsed parameters look good */
+.params-content dl.field-list {
+ margin-top: 0;
+}
+
+.params-content > dt {
+ margin-top: 0.5rem;
+}
+
+.params-content > dt:first-child {
+ margin-top: 0;
+}
+
+/* Responsive adjustments */
+@media (max-width: 768px) {
+ .params-toggle-btn {
+ width: 100%;
+ text-align: left;
+ }
+}
diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js
new file mode 100644
index 000000000..fa794ae89
--- /dev/null
+++ b/docs/source/_static/custom.js
@@ -0,0 +1,193 @@
+// Lightbox functionality for Mermaid diagrams
+document.addEventListener('DOMContentLoaded', function() {
+ // Create lightbox modal for Mermaid diagrams
+ const lightbox = document.createElement('div');
+ lightbox.id = 'mermaid-lightbox';
+ lightbox.style.cssText = `
+ display: none;
+ position: fixed;
+ z-index: 9999;
+ left: 0;
+ top: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgba(0,0,0,0.9);
+ cursor: zoom-out;
+ `;
+
+ const lightboxContent = document.createElement('div');
+ lightboxContent.style.cssText = `
+ position: absolute;
+ top: 50%;
+ left: 50%;
+ transform: translate(-50%, -50%);
+ max-width: 95%;
+ max-height: 95%;
+ overflow: auto;
+ `;
+
+ const closeBtn = document.createElement('span');
+ closeBtn.innerHTML = '×';
+ closeBtn.style.cssText = `
+ position: absolute;
+ top: 15px;
+ right: 35px;
+ color: #f1f1f1;
+ font-size: 40px;
+ font-weight: bold;
+ cursor: pointer;
+ z-index: 10000;
+ `;
+ closeBtn.onclick = function() {
+ lightbox.style.display = 'none';
+ };
+
+ lightbox.appendChild(closeBtn);
+ lightbox.appendChild(lightboxContent);
+ document.body.appendChild(lightbox);
+
+ // Click outside to close
+ lightbox.onclick = function(e) {
+ if (e.target === lightbox) {
+ lightbox.style.display = 'none';
+ }
+ };
+
+ // ESC key to close
+ document.addEventListener('keydown', function(e) {
+ if (e.key === 'Escape' && lightbox.style.display === 'block') {
+ lightbox.style.display = 'none';
+ }
+ });
+
+ // Make all Mermaid diagrams clickable
+ const mermaidDivs = document.querySelectorAll('.mermaid');
+ mermaidDivs.forEach(function(div) {
+ div.style.cursor = 'zoom-in';
+ div.title = 'Click to enlarge';
+
+ div.addEventListener('click', function() {
+ const clone = div.cloneNode(true);
+
+ // Style the cloned diagram to fill the lightbox
+ clone.style.cssText = `
+ cursor: default;
+ width: 90vw;
+ max-width: 1400px;
+ height: auto;
+ margin: auto;
+ `;
+
+ // Find and resize the SVG inside
+ const svg = clone.querySelector('svg');
+ if (svg) {
+ svg.style.cssText = `
+ width: 100% !important;
+ height: auto !important;
+ max-width: none !important;
+ max-height: 90vh !important;
+ `;
+ svg.removeAttribute('width');
+ svg.removeAttribute('height');
+ }
+
+ lightboxContent.innerHTML = '';
+ lightboxContent.appendChild(clone);
+ lightbox.style.display = 'block';
+ });
+ });
+});
+
+// Custom JavaScript to make long parameter lists in class signatures collapsible
+document.addEventListener('DOMContentLoaded', function() {
+ console.log('Collapsible parameters script loaded');
+
+ // Find all class/function signatures
+ const signatures = document.querySelectorAll('dl.py.class > dt, dl.py.function > dt, dl.py.method > dt');
+
+ signatures.forEach(function(signature) {
+ // Find all parameter elements in the signature
+ const params = signature.querySelectorAll('em.sig-param, .sig-param');
+
+ console.log(`Found signature with ${params.length} parameters`);
+
+ // Only make it collapsible if there are more than 10 parameters
+ if (params.length > 10) {
+ console.log('Creating collapsible structure for signature with', params.length, 'parameters');
+
+ const visibleCount = 5;
+ const hiddenCount = params.length - visibleCount;
+
+ // Create a wrapper div for the toggle button
+ const wrapper = document.createElement('span');
+ wrapper.className = 'sig-params-wrapper';
+ wrapper.style.display = 'inline';
+
+ // Create toggle button
+ const toggleBtn = document.createElement('button');
+ toggleBtn.className = 'params-toggle-btn-inline';
+ toggleBtn.innerHTML = ` Show More`;
+ toggleBtn.setAttribute('aria-expanded', 'false');
+ toggleBtn.title = `Show ${hiddenCount} more parameters`;
+
+ // Collect all nodes to hide (params and text nodes between them)
+ const nodesToHide = [];
+
+ // Hide parameters after the first 3
+ let insertedButton = false;
+ params.forEach(function(param, index) {
+ if (index >= visibleCount) {
+ // Add 'hidden' class to hide the parameter
+ param.classList.add('sig-param-hidden');
+ nodesToHide.push(param);
+
+ // Also hide the text node (comma/space) that follows this parameter
+ let nextNode = param.nextSibling;
+ while (nextNode && nextNode.nodeType === Node.TEXT_NODE) {
+ const textSpan = document.createElement('span');
+ textSpan.className = 'sig-param-hidden';
+ textSpan.textContent = nextNode.textContent;
+ nextNode.parentNode.replaceChild(textSpan, nextNode);
+ nodesToHide.push(textSpan);
+ break;
+ }
+
+ // Insert the toggle button before the first hidden parameter
+ if (!insertedButton) {
+ param.parentNode.insertBefore(wrapper, param);
+ wrapper.appendChild(toggleBtn);
+ insertedButton = true;
+ }
+ }
+ });
+
+ // Add click handler to toggle
+ toggleBtn.addEventListener('click', function(e) {
+ e.preventDefault();
+ e.stopPropagation();
+
+ const isExpanded = toggleBtn.getAttribute('aria-expanded') === 'true';
+
+ if (isExpanded) {
+ // Collapse: hide parameters again
+ nodesToHide.forEach(function(node) {
+ node.classList.add('sig-param-hidden');
+ });
+ toggleBtn.setAttribute('aria-expanded', 'false');
+ toggleBtn.innerHTML = ` Show More`;
+ toggleBtn.title = `Show ${hiddenCount} more parameters`;
+ } else {
+ // Expand: show all parameters
+ nodesToHide.forEach(function(node) {
+ node.classList.remove('sig-param-hidden');
+ });
+ toggleBtn.setAttribute('aria-expanded', 'true');
+ toggleBtn.innerHTML = ` Hide`;
+ toggleBtn.title = `Hide ${hiddenCount} parameters`;
+ }
+ });
+
+ console.log('Collapsible structure created successfully');
+ }
+ });
+});
diff --git a/docs/source/_static/logo-icon.svg b/docs/source/_static/logo-icon.svg
new file mode 100644
index 000000000..9dcafc39a
--- /dev/null
+++ b/docs/source/_static/logo-icon.svg
@@ -0,0 +1,12 @@
+
+
+
+
diff --git a/docs/source/api.md b/docs/source/api.md
index 5ed009c4c..5c846de91 100644
--- a/docs/source/api.md
+++ b/docs/source/api.md
@@ -1,35 +1,35 @@
# API Reference
-This section provides comprehensive API documentation for TorchForge modules and classes.
+This section provides comprehensive API documentation for TorchForge.
-TorchForge is organized into several key modules, each providing specialized functionality for post-training generative AI models:
+## Overview
-## Module Overview
+TorchForge is a PyTorch native platform for post-training generative AI models,
+designed to streamline reinforcement learning workflows for large language
+models. The platform leverages PyTorch's distributed computing capabilities
+and is built on top of [Monarch](https://meta-pytorch.org/monarch/),
+making extensive use of actors for distributed computation and fault tolerance.
-**Core Components**
-- [Interfaces & Types](api_core.md) - Core interfaces and type definitions
-- [Actors](api_actors.md) - Model training and inference components
-- [Controller](api_controller.md) - Distributed training orchestration and resource management
+Key Features of TorchForge include:
-**Data Management**
-- [Data](api_data.md) - Data handling utilities, datasets, and data models
+- **Actor-Based Architecture**: TorchForge uses an actor-based system for distributed training, providing excellent scalability and fault tolerance.
+- **PyTorch Native**: Built natively on PyTorch, ensuring seamless integration with existing PyTorch workflows.
+- **Post-Training Focus**: Specifically designed for post-training techniques like RLVR, SFT, and other alignment methods.
+- **Distributed by Design**: Supports multi-GPU and multi-node training out of the box.
-**Training Components**
-- [Losses](api_losses.md) - Loss functions for reinforcement learning and supervised fine-tuning
-- [Environments](api_envs.md) - Training and inference environments
-**Tools & Utilities**
-- [Utilities](api_util.md) - General utility functions and helpers
+For most use cases, you'll interact with the high-level service
+interfaces, which handle the complexity of actor coordination and
+distributed training automatically.
-```{toctree}
-:maxdepth: 2
-:hidden:
+For advanced users who need fine-grained control, the individual actor
+APIs provide direct access to the underlying distributed components.
-api_core
+```{toctree}
+:maxdepth: 1
api_actors
-api_data
-api_losses
-api_envs
-api_controller
-api_util
+api_service
+api_generator
+api_model
+api_trainer
```
diff --git a/docs/source/api_actors.md b/docs/source/api_actors.md
index 6ef5f1ff8..e53101b27 100644
--- a/docs/source/api_actors.md
+++ b/docs/source/api_actors.md
@@ -1,19 +1,20 @@
-# Actors
-
-The actors module contains the core components for model training and inference in TorchForge. This includes policy actors, reference models, replay buffers, and trainers.
-
-## Policy Actor
-
-The policy actor is responsible for model inference and policy interactions during training.
-
-## Reference Model
-
-The reference model provides baseline comparisons for reinforcement learning algorithms.
-
-## Replay Buffer
-
-The replay buffer manages storage and sampling of training experiences.
-
-## Trainer
-
-The trainer orchestrates the training process and implements training algorithms.
+# ForgeActor
+
+```{eval-rst}
+.. currentmodule:: forge.actors
+```
+
+The actors module contains the core components for model training
+and inference in TorchForge. These pre-built actors provide essential
+functionality for reinforcement learning workflows and can be used
+as building blocks for complex distributed training systems.
+
+```{eval-rst}
+.. currentmodule:: forge.controller.actor
+
+.. autoclass:: ForgeActor
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :exclude-members: logger, setup, set_env, __init__, as_service
+```
diff --git a/docs/source/api_controller.md b/docs/source/api_controller.md
deleted file mode 100644
index e9bedda74..000000000
--- a/docs/source/api_controller.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Controller
-
-Distributed training orchestration and resource management components for TorchForge.
diff --git a/docs/source/api_core.md b/docs/source/api_core.md
deleted file mode 100644
index 75b3e9ae5..000000000
--- a/docs/source/api_core.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Core Interfaces
-
-This section covers the fundamental interfaces and type definitions that form the foundation of TorchForge.
diff --git a/docs/source/api_data.md b/docs/source/api_data.md
deleted file mode 100644
index cbc1cfc53..000000000
--- a/docs/source/api_data.md
+++ /dev/null
@@ -1,16 +0,0 @@
-# Data Management
-
-Comprehensive data handling utilities for training and
-inference, including datasets, data models, and various
-data processing utilities.
-
-## Prompt
-
-Data model for input prompts and contexts.
-
-```{eval-rst}
-.. automodule:: forge.data_models.prompt
- :members:
- :undoc-members:
- :show-inheritance:
-```
diff --git a/docs/source/api_envs.md b/docs/source/api_envs.md
deleted file mode 100644
index 88e9d1cea..000000000
--- a/docs/source/api_envs.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# Environments
-
-Training and inference environments for TorchForge models.
-
-
-## Chat Environment
-
-Chat-based environment for conversational AI training and inference.
diff --git a/docs/source/api_generator.md b/docs/source/api_generator.md
new file mode 100644
index 000000000..31b67c03c
--- /dev/null
+++ b/docs/source/api_generator.md
@@ -0,0 +1,27 @@
+# Generator
+
+```{eval-rst}
+.. currentmodule:: forge.actors.generator
+```
+
+The Generator (Policy) is the core inference engine in TorchForge,
+built on top of [vLLM](https://docs.vllm.ai/en/latest/).
+It manages model serving, text generation, and weight updates for reinforcement learning workflows.
+
+## Generator
+
+```{eval-rst}
+.. autoclass:: Generator
+ :members: generate, update_weights, get_version, stop
+ :exclude-members: __init__, launch
+ :no-inherited-members:
+```
+
+## GeneratorWorker
+
+```{eval-rst}
+.. autoclass:: GeneratorWorker
+ :members: execute_model, update, setup_kv_cache
+ :show-inheritance:
+ :exclude-members: __init__
+```
diff --git a/docs/source/api_losses.md b/docs/source/api_losses.md
deleted file mode 100644
index 097b83394..000000000
--- a/docs/source/api_losses.md
+++ /dev/null
@@ -1,11 +0,0 @@
-# Losses
-
-Loss functions for reinforcement learning and supervised fine-tuning in TorchForge.
-
-## GRPO Loss
-
-Generalized Reward Policy Optimization (GRPO) loss implementation for reinforcement learning.
-
-## Reinforce Loss
-
-Reinforce algorithm loss implementation for policy gradient methods.
diff --git a/docs/source/api_model.md b/docs/source/api_model.md
new file mode 100644
index 000000000..94e51478e
--- /dev/null
+++ b/docs/source/api_model.md
@@ -0,0 +1,29 @@
+# Model
+
+```{eval-rst}
+.. currentmodule:: forge.actors.reference_model
+```
+
+The {class}`forge.actors.reference_model.ReferenceModel` provides a frozen
+copy of the policy model used for computing advantages in reinforcement
+learning. It performs inference on input sequences and returns logits or
+log probabilities for computing KL divergence and other RL metrics.
+
+## ReferenceModel
+
+```{eval-rst}
+.. autoclass:: forge.actors.reference_model.ReferenceModel
+ :members:
+ :undoc-members:
+ :show-inheritance:
+```
+
+The ReferenceModel uses a subset of TorchTitan's configuration system:
+
+- **model**: Model architecture settings (Model dataclass)
+- **parallelism**: Parallelism configuration for distributed inference (Parallelism dataclass)
+- **checkpoint**: Checkpoint loading settings (Checkpoint dataclass)
+- **compile**: Model compilation settings (Compile dataclass)
+- **training**: Training configuration for dtype and other settings (Training dataclass)
+
+For detailed configuration options, refer to the [TorchTitan documentation](https://github.com/pytorch/torchtitan).
diff --git a/docs/source/api_service.md b/docs/source/api_service.md
new file mode 100644
index 000000000..df2bf3dc8
--- /dev/null
+++ b/docs/source/api_service.md
@@ -0,0 +1,12 @@
+# Service
+
+```{eval-rst}
+.. currentmodule:: forge.controller.service.service
+```
+
+```{eval-rst}
+.. autoclass:: Service
+
+ :members: call_all, start_session, get_metrics, get_metrics_summary, terminate_session, stop
+ :show-inheritance:
+```
diff --git a/docs/source/api_trainer.md b/docs/source/api_trainer.md
new file mode 100644
index 000000000..6e66e5418
--- /dev/null
+++ b/docs/source/api_trainer.md
@@ -0,0 +1,68 @@
+# Trainer
+
+```{eval-rst}
+.. currentmodule:: forge.actors.trainer
+```
+
+The Trainer manages model training in TorchForge, built on top of TorchTitan.
+It handles forward/backward passes, weight updates, and checkpoint management for reinforcement learning workflows.
+
+## TitanTrainer
+
+```{eval-rst}
+.. autoclass:: TitanTrainer
+ :members: train_step, push_weights, cleanup
+ :exclude-members: __init__
+```
+
+## Configuration
+
+The TitanTrainer uses TorchTitan's configuration system with the following components:
+
+### Job Configuration
+
+```{eval-rst}
+.. autoclass:: torchtitan.config.job_config.Job
+ :members:
+ :undoc-members:
+```
+
+### Model Configuration
+
+```{eval-rst}
+.. autoclass:: torchtitan.config.job_config.Model
+ :members:
+ :undoc-members:
+```
+
+### Optimizer Configuration
+
+```{eval-rst}
+.. autoclass:: torchtitan.config.job_config.Optimizer
+ :members:
+ :undoc-members:
+```
+
+### Training Configuration
+
+```{eval-rst}
+.. autoclass:: torchtitan.config.job_config.Training
+ :members:
+ :undoc-members:
+```
+
+### Parallelism Configuration
+
+```{eval-rst}
+.. autoclass:: torchtitan.config.job_config.Parallelism
+ :members:
+ :undoc-members:
+```
+
+### Checkpoint Configuration
+
+```{eval-rst}
+.. autoclass:: torchtitan.config.job_config.Checkpoint
+ :members:
+ :undoc-members:
+```
diff --git a/docs/source/api_util.md b/docs/source/api_util.md
deleted file mode 100644
index f15e03b76..000000000
--- a/docs/source/api_util.md
+++ /dev/null
@@ -1,25 +0,0 @@
-# Utilities
-
-General utility functions and helpers used throughout TorchForge.
-
-## Distributed Computing
-
-Utilities for distributed training and communication.
-
-```{eval-rst}
-.. automodule:: forge.util.distributed
- :members:
- :undoc-members:
- :show-inheritance:
-```
-
-## Logging
-
-Logging configuration and utilities.
-
-```{eval-rst}
-.. automodule:: forge.util.logging
- :members:
- :undoc-members:
- :show-inheritance:
-```
diff --git a/docs/source/concepts.md b/docs/source/concepts.md
deleted file mode 100644
index 075d1ef7f..000000000
--- a/docs/source/concepts.md
+++ /dev/null
@@ -1,4 +0,0 @@
-# Concepts
-
-This guide covers the fundamental concepts and architecture behind TorchForge,
-helping you understand how the system works and how to effectively use its components.
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 69b3c20b5..951f9c1cd 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -17,10 +17,33 @@
import pytorch_sphinx_theme2
# Add the source directory to Python path so modules can be imported
-sys.path.insert(0, os.path.abspath("../../src"))
+sys.path.insert(0, os.path.abspath("../../src/forge"))
+
+
+# Determine the version path for deployment
+def get_version_path():
+ """Get the version path based on environment variables or git context."""
+ # Check if we're in CI/CD and get the target folder
+ github_ref = os.environ.get("GITHUB_REF", "")
+
+ # Convert refs/tags/v1.12.0rc3 into 1.12.
+ # Matches the logic in .github/workflows/docs.yml
+ if github_ref.startswith("refs/tags/v"):
+ import re
+
+ match = re.match(r"^refs/tags/v([0-9]+\.[0-9]+)\..*", github_ref)
+ if match:
+ return match.group(1) + "/"
+
+ # Default to main for main branch or local development
+ return "main/"
+
+
+# Set base URL based on deployment context
+version_path = get_version_path()
project = "torchforge"
-copyright = "2025, PyTorch Contributors"
+copyright = ""
author = "PyTorch Contributors"
release = "0.1"
@@ -35,12 +58,18 @@
"myst_parser",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
+ "sphinx_autodoc_typehints",
"sphinx.ext.napoleon",
"sphinx.ext.intersphinx",
"sphinx.ext.viewcode",
+ "sphinx_gallery.gen_gallery",
]
-html_baseurl = "https://meta-pytorch.org/forge/" # needed for sphinx-sitemap
+html_favicon = "_static/logo-icon.svg"
+
+html_baseurl = (
+ f"https://meta-pytorch.org/torchforge/{version_path}" # needed for sphinx-sitemap
+)
sitemap_locales = [None]
sitemap_excludes = [
"search.html",
@@ -48,11 +77,24 @@
]
sitemap_url_scheme = "{link}"
+# Ensure static files use relative paths
+html_static_path = ["_static"]
+
templates_path = [
"_templates",
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
]
-exclude_patterns = []
+
+exclude_patterns = [
+ "tutorials/index.rst",
+ "tutorials/template_tutorial.rst",
+ "tutorials/**/index.rst",
+ "tutorial_sources/**/*.md", # Exclude all markdown files from tutorial_sources
+ "tutorial_sources/**/*.MD", # Also exclude uppercase .MD files
+]
+html_static_path = ["_static"]
+html_css_files = ["custom.css"]
+html_js_files = ["custom.js"]
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../../src"))
@@ -80,7 +122,7 @@
},
{
"name": "GitHub",
- "url": "https://github.com/meta-pytorch/forge",
+ "url": "https://github.com/meta-pytorch/torchforge",
"icon": "fa-brands fa-github",
},
{
@@ -90,14 +132,16 @@
},
{
"name": "PyPi",
- "url": "https://pypi.org/project/forge/",
+ "url": "https://pypi.org/project/torchforge/",
"icon": "fa-brands fa-python",
},
],
"use_edit_page_button": True,
"navbar_center": "navbar-nav",
- "canonical_url": "https://meta-pytorch.org/forge/",
+ "canonical_url": "https://meta-pytorch.org/torchforge/",
"header_links_before_dropdown": 7,
+ "show_toc_level": 2,
+ "navigation_depth": 3,
}
theme_variables = pytorch_sphinx_theme2.get_theme_variables()
@@ -107,26 +151,150 @@
"display_github": True,
"github_url": "https://github.com",
"github_user": "meta-pytorch",
- "github_repo": "forge",
+ "github_repo": "torchforge",
"feedback_url": "https://github.com/meta-pytorch/forge",
+ "colab_branch": "gh-pages",
"github_version": "main",
"doc_path": "docs/source",
+ "has_sphinx_gallery": True, # Enable tutorial call-to-action links
}
+# For tutorial repository configuration
+# Note: github_user and github_repo are combined in the template as "{{ github_user }}/{{ github_repo }}"
+# So we keep github_user = "meta-pytorch" and github_repo = "forge" already set above
+# and only need to ensure the branch settings are correct
+tutorial_repo_config = {
+ "github_version": "main", # This maps to github_branch in the template
+ "colab_branch": "gh-pages",
+}
+html_context.update(tutorial_repo_config)
+
myst_enable_extensions = [
"colon_fence",
"deflist",
"html_image",
+ "substitution",
]
+# Configure MyST parser to treat mermaid code blocks as mermaid directives
+myst_fence_as_directive = ["mermaid"]
+
+# Disable D3 zoom (we'll use lightbox instead)
+mermaid_d3_zoom = False
+
+# Global Mermaid theme configuration - applies to all diagrams
+mermaid_init_js = """
+import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@11.2.0/dist/mermaid.esm.min.mjs';
+mermaid.initialize({
+ startOnLoad: false,
+ theme: 'base',
+ themeVariables: {
+ primaryColor: '#4CAF50',
+ primaryTextColor: '#000',
+ primaryBorderColor: '#fff',
+ lineColor: '#555',
+ secondaryColor: '#FF9800',
+ tertiaryColor: '#ffffde'
+ },
+ flowchart: {
+ curve: 'basis'
+ },
+ themeCSS: '.edgePath .path { stroke-width: 4px; stroke: #555; }'
+});
+"""
+
autodoc_default_options = {
"members": True,
- "member-order": "bysource",
- "special-members": "__init__",
"undoc-members": True,
- "exclude-members": "__weakref__",
+ "private-members": False,
+ "inherited-members": False,
+}
+
+# Autodoc configuration for cleaner signatures
+autodoc_preserve_defaults = True # Preserves default values without expansion
+autodoc_typehints = "description" # Move type hints to description instead of signature
+autodoc_typehints_description_target = (
+ "documented_params" # Only add types to documented params
+)
+
+# Disable docstring inheritance
+autodoc_inherit_docstrings = False
+autodoc_typehints = "none"
+
+
+# Removed suppress_warnings to make the build stricter
+# All warnings will now be treated as errors when -W is passed to sphinx-build
+
+# Be strict about references to catch broken links and references
+nitpicky = False
+
+# Napoleon settings for Google-style docstrings (from torchtitan and other dependencies)
+napoleon_google_docstring = True
+napoleon_numpy_docstring = True
+napoleon_use_param = True
+napoleon_use_rtype = True
+napoleon_use_ivar = True
+
+
+# -- Sphinx Gallery configuration -------------------------------------------
+sphinx_gallery_conf = {
+ "examples_dirs": "tutorial_sources", # Path to examples directory
+ "gallery_dirs": "tutorials", # Path to generate gallery
+ "filename_pattern": r".*\.py$", # Only process .py files, not .md files
+ "download_all_examples": False,
+ "first_notebook_cell": "%matplotlib inline",
+ "plot_gallery": "True",
+ "promote_jupyter_magic": True,
+ "backreferences_dir": None,
+ "show_signature": False,
+ "write_computation_times": False,
+ "ignore_pattern": r".*\.md$|.*\.MD$", # Explicitly ignore markdown files
}
-# Autosummary settings
-autosummary_generate = True
-autosummary_imported_members = True
+
+def clean_docstring_indentation(app, what, name, obj, options, lines):
+ if name and name.startswith("torchtitan."):
+ lines[:] = [line.lstrip() for line in lines]
+ if lines and lines[-1].strip():
+ lines.append("")
+
+
+def copy_markdown_tutorials(app):
+ """Copy markdown files from tutorial_sources to tutorials directory.
+
+ This runs after the builder is initialized but before sphinx-gallery processes files,
+ ensuring markdown files are available alongside generated .py tutorials.
+ """
+ import shutil
+ from pathlib import Path
+
+ source_dir = Path(app.srcdir) / "tutorial_sources"
+ target_dir = Path(app.srcdir) / "tutorials"
+
+ # Ensure target directory exists
+ target_dir.mkdir(parents=True, exist_ok=True)
+
+ # Walk through tutorial_sources and copy all .md files
+ for md_file in source_dir.rglob("*.md"):
+ # Skip README files
+ if md_file.name.lower() in ["readme.md", "readme.txt"]:
+ continue
+
+ # Calculate relative path from tutorial_sources
+ rel_path = md_file.relative_to(source_dir)
+
+ # Create target path in tutorials directory
+ target_path = target_dir / rel_path
+ target_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Copy the file
+ shutil.copy2(md_file, target_path)
+ print(
+ f"[TorchForge Docs] Copied {md_file.name} to {target_path.relative_to(app.srcdir)}"
+ )
+
+
+def setup(app):
+ app.connect("autodoc-process-docstring", clean_docstring_indentation)
+ # Use builder-inited to ensure it runs before source files are read
+ app.connect("builder-inited", copy_markdown_tutorials)
diff --git a/docs/source/faq.md b/docs/source/faq.md
deleted file mode 100644
index d3c027866..000000000
--- a/docs/source/faq.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# FAQ
-
-This FAQ covers common questions and issues encountered when using TorchForge.
diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index 57e1b63c8..ce2928b77 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -1,9 +1,282 @@
-# Get Started
+# Getting Started
-Welcome to TorchForge! This guide will help you get up and running with TorchForge, a PyTorch-native platform specifically designed for post-training generative AI models.
+This guide will walk you through installing TorchForge, understanding its dependencies, verifying your setup, and running your first training job.
-TorchForge specializes in post-training techniques for large language models, including:
+## System Requirements
-- **Supervised Fine-Tuning (SFT)**: Adapt pre-trained models to specific tasks using labeled data
-- **Generalized Reward Policy Optimization (GRPO)**: Advanced reinforcement learning for model alignment
-- **Multi-GPU Distributed Training**: Efficient scaling across multiple GPUs and nodes
+Before installing TorchForge, ensure your system meets the following requirements.
+
+| Component | Requirement | Notes |
+|-----------|-------------|-------|
+| **Operating System** | Linux (Fedora/Ubuntu/Debian) | MacOS and Windows not currently supported |
+| **Python** | 3.10 or higher | Python 3.11 recommended |
+| **GPU** | NVIDIA with CUDA support | AMD GPUs not currently supported |
+| **Minimum GPUs** | 2+ for SFT, 3+ for GRPO | More GPUs enable larger models |
+| **CUDA** | 12.8 | Required for GPU training |
+| **RAM** | 32GB+ recommended | Depends on model size |
+| **Disk Space** | 50GB+ free | For models, datasets, and checkpoints |
+| **PyTorch** | Nightly build | Latest distributed features (DTensor, FSDP) |
+| **Monarch** | Pre-packaged wheel | Distributed orchestration and actor system |
+| **vLLM** | v0.10.0+ | Fast inference with PagedAttention |
+| **TorchTitan** | Latest | Production training infrastructure |
+
+
+## Prerequisites
+
+- **Conda or Miniconda**: For environment management
+ - Download from [conda.io](https://docs.conda.io/en/latest/miniconda.html)
+
+- **GitHub CLI (gh)**: Required for downloading pre-packaged dependencies
+ - Install instructions: [github.com/cli/cli#installation](https://github.com/cli/cli#installation)
+ - After installing, authenticate with: `gh auth login`
+ - You can use either HTTPS or SSH as the authentication protocol
+
+- **Git**: For cloning the repository
+ - Usually pre-installed on Linux systems
+ - Verify with: `git --version`
+
+
+**Installation note:** The installation script provides pre-built wheels with PyTorch nightly already included.
+
+## Installation
+
+TorchForge uses pre-packaged wheels for all dependencies, making installation faster and more reliable.
+
+1. **Clone the Repository**
+
+ ```bash
+ git clone https://github.com/meta-pytorch/forge.git
+ cd forge
+ ```
+
+2. **Create Conda Environment**
+
+ ```bash
+ conda create -n forge python=3.10
+ conda activate forge
+ ```
+
+3. **Run Installation Script**
+
+ ```bash
+ ./scripts/install.sh
+ ```
+
+ The installation script will:
+ - Install system dependencies using DNF (or your package manager)
+ - Download pre-built wheels for PyTorch nightly, Monarch, vLLM, and TorchTitan
+ - Install TorchForge and all Python dependencies
+ - Configure the environment for GPU training
+
+ ```{tip}
+ **Using sudo instead of conda**: If you prefer installing system packages directly rather than through conda, use:
+ `./scripts/install.sh --use-sudo`
+ ```
+
+ ```{warning}
+ When adding packages to `pyproject.toml`, use `uv sync --inexact` to avoid removing Monarch and vLLM.
+ ```
+
+## Verifying Your Setup
+
+After installation, verify that all components are working correctly:
+
+1. **Check GPU Availability**
+
+ ```bash
+ python -c "import torch; print(f'GPUs available: {torch.cuda.device_count()}')"
+ ```
+
+ Expected output: `GPUs available: 2` (or more)
+
+2. **Check CUDA Version**
+
+ ```bash
+ python -c "import torch; print(f'CUDA version: {torch.version.cuda}')"
+ ```
+
+ Expected output: `CUDA version: 12.8`
+3. **Check All Dependencies**
+
+ ```bash
+ # Check core components
+ python -c "import torch, forge, monarch, vllm; print('All imports successful')"
+
+ # Check specific versions
+ python -c "
+ import torch
+ import forge
+ import vllm
+
+ print(f'PyTorch: {torch.__version__}')
+ print(f'TorchForge: {forge.__version__}')
+ print(f'vLLM: {vllm.__version__}')
+ print(f'CUDA: {torch.version.cuda}')
+ print(f'GPUs: {torch.cuda.device_count()}')
+ "
+ ```
+
+4. **Verify Monarch**
+
+ ```bash
+ python -c "
+ from monarch.actor import Actor, this_host
+
+ # Test basic Monarch functionality
+ procs = this_host().spawn_procs({'gpus': 1})
+ procs.initialized.get()
+ print('Monarch: Process spawning works')
+ "
+ ```
+
+## Quick Start Examples
+
+Now that TorchForge is installed, let's run some training examples.
+
+Here's what training looks like with TorchForge:
+
+```bash
+# Install dependencies
+conda create -n forge python=3.10
+conda activate forge
+git clone https://github.com/meta-pytorch/forge
+cd forge
+./scripts/install.sh
+
+# Download a model
+hf download meta-llama/Meta-Llama-3.1-8B-Instruct --local-dir /tmp/Meta-Llama-3.1-8B-Instruct --exclude "original/consolidated.00.pth"
+
+# Run SFT training (requires 2+ GPUs)
+uv run forge run --nproc_per_node 2 \
+ apps/sft/main.py --config apps/sft/llama3_8b.yaml
+
+# Run GRPO training (requires 3+ GPUs)
+python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+```
+
+### Example 1: Supervised Fine-Tuning (SFT)
+
+Fine-tune Llama 3 8B on your data. **Requires: 2+ GPUs**
+
+1. **Access the Model**
+
+ ```{note}
+ Model downloads are no longer required, but Hugging Face authentication is required to access the models.
+
+ Run `huggingface-cli login` first if you haven't already.
+ ```
+
+2. **Run Training**
+
+ ```bash
+ python -m apps.sft.main --config apps/sft/llama3_8b.yaml
+ ```
+
+ **What's Happening:**
+ - `--nproc_per_node 2`: Use 2 GPUs for training
+ - `apps/sft/main.py`: SFT training script
+ - `--config apps/sft/llama3_8b.yaml`: Configuration file with hyperparameters
+ - **TorchTitan** handles model sharding across the 2 GPUs
+ - **Monarch** coordinates the distributed training
+
+### Example 2: GRPO Training
+
+Train a model using reinforcement learning with GRPO. **Requires: 3+ GPUs**
+
+```bash
+python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+```
+
+**What's Happening:**
+- GPU 0: Trainer model (being trained, powered by TorchTitan)
+- GPU 1: Reference model (frozen baseline, powered by TorchTitan)
+- GPU 2: Policy model (scoring outputs, powered by vLLM)
+- **Monarch** orchestrates all three components
+- **TorchStore** handles weight synchronization from training to inference
+
+## Understanding Configuration Files
+
+TorchForge uses YAML configuration files to manage training parameters. Let's examine a typical config:
+
+```yaml
+# Example: apps/sft/llama3_8b.yaml
+model:
+ name: meta-llama/Meta-Llama-3.1-8B-Instruct
+ path: /tmp/Meta-Llama-3.1-8B-Instruct
+
+training:
+ batch_size: 4
+ learning_rate: 1e-5
+ num_epochs: 10
+ gradient_accumulation_steps: 4
+
+distributed:
+ strategy: fsdp # Managed by TorchTitan
+ precision: bf16
+
+checkpointing:
+ save_interval: 1000
+ output_dir: /tmp/checkpoints
+```
+
+**Key Sections:**
+- **model**: Model path and settings
+- **training**: Hyperparameters like batch size and learning rate
+- **distributed**: Multi-GPU strategy (FSDP, tensor parallel, etc.) handled by TorchTitan
+- **checkpointing**: Where and when to save model checkpoints
+
+## Next Steps
+
+Now that you have TorchForge installed and verified:
+
+1. **Explore Examples**: Check the `apps/` directory for more training examples
+2. **Read Tutorials**: Follow {doc}`tutorials` for step-by-step guides
+3. **API Documentation**: Explore {doc}`api` for detailed API reference
+
+## Getting Help
+
+If you encounter issues:
+
+1. **Search Issues**: Look through [GitHub Issues](https://github.com/meta-pytorch/forge/issues)
+2. **File a Bug Report**: Create a new issue with:
+ - Your system configuration (output of diagnostic command below)
+ - Full error message
+ - Steps to reproduce
+ - Expected vs actual behavior
+
+**Diagnostic command:**
+```bash
+python -c "
+import torch
+import forge
+
+try:
+ import monarch
+ monarch_status = 'OK'
+except Exception as e:
+ monarch_status = str(e)
+
+try:
+ import vllm
+ vllm_version = vllm.__version__
+except Exception as e:
+ vllm_version = str(e)
+
+print(f'PyTorch: {torch.__version__}')
+print(f'TorchForge: {forge.__version__}')
+print(f'Monarch: {monarch_status}')
+print(f'vLLM: {vllm_version}')
+print(f'CUDA: {torch.version.cuda}')
+print(f'GPUs: {torch.cuda.device_count()}')
+"
+```
+
+Include this output in your bug reports!
+
+## Additional Resources
+
+- **Contributing Guide**: [CONTRIBUTING.md](https://github.com/meta-pytorch/forge/blob/main/CONTRIBUTING.md)
+- **Code of Conduct**: [CODE_OF_CONDUCT.md](https://github.com/meta-pytorch/forge/blob/main/CODE_OF_CONDUCT.md)
+- **Monarch Documentation**: [meta-pytorch.org/monarch](https://meta-pytorch.org/monarch)
+- **vLLM Documentation**: [docs.vllm.ai](https://docs.vllm.ai)
+- **TorchTitan**: [github.com/pytorch/torchtitan](https://github.com/pytorch/torchtitan)
diff --git a/docs/source/index.md b/docs/source/index.md
index c450b4ca7..074fa228f 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -1,13 +1,185 @@
-# Welcome to TorchForge documentation!
+# TorchForge Documentation
-**TorchForge** is a PyTorch-native platform specifically designed
-for post-training generative AI models.
+**TorchForge** is a PyTorch-native library for RL post-training and agentic development. Built on the principle that **researchers should write algorithms, not infrastructure**.
-Key Features
-------------
+```{note}
+**Experimental Status:** TorchForge is currently in early development. Expect bugs, incomplete features, and API changes. Please file issues on [GitHub](https://github.com/meta-pytorch/forge) for bug reports and feature requests.
+```
+
+## Why TorchForge?
+
+Reinforcement Learning has become essential to frontier AI - from instruction following and reasoning to complex research capabilities. But infrastructure complexity often dominates the actual research.
+
+TorchForge lets you **express RL algorithms as naturally as pseudocode**, while powerful infrastructure handles distribution, fault tolerance, and optimization underneath.
+
+### Core Design Principles
+
+- **Algorithms, Not Infrastructure**: Write your RL logic without distributed systems code
+- **Any Degree of Asynchrony**: From fully synchronous PPO to fully async off-policy training
+- **Composable Components**: Mix and match proven frameworks (vLLM, TorchTitan) with custom logic
+- **Built on Solid Foundations**: Leverages Monarch's single-controller model for simplified distributed programming
+
+## Foundation: The Technology Stack
+
+TorchForge is built on carefully selected, battle-tested components:
+
+::::{grid} 1 1 2 2
+:gutter: 3
+
+:::{grid-item-card} **Monarch**
+:link: https://meta-pytorch.org/monarch
+
+Single-controller distributed programming framework that orchestrates clusters like you'd program a single machine. Provides actor meshes, fault tolerance, and RDMA-based data transfers.
+
+**Why it matters:** Eliminates SPMD complexity, making distributed RL tractable
+:::
+
+:::{grid-item-card} **vLLM**
+:link: https://docs.vllm.ai
+
+High-throughput, memory-efficient inference engine with PagedAttention and continuous batching.
+
+**Why it matters:** Handles policy generation efficiently at scale
+:::
+
+:::{grid-item-card} **TorchTitan**
+:link: https://github.com/pytorch/torchtitan
+
+Meta's production-grade LLM training platform with FSDP, pipeline parallelism, and tensor parallelism.
+
+**Why it matters:** Battle-tested training infrastructure proven at scale
+:::
+
+:::{grid-item-card} **TorchStore**
+:link: https://github.com/meta-pytorch/torchstore
+
+Distributed, in-memory key-value store for PyTorch tensors built on Monarch, optimized for weight synchronization with automatic DTensor resharding.
+
+**Why it matters:** Solves the weight transfer bottleneck in async RL
+:::
+
+::::
+
+## What You Can Build
+
+::::{grid} 1 1 2 3
+:gutter: 2
+
+:::{grid-item-card} Supervised Fine-Tuning
+Adapt foundation models to specific tasks using labeled data with efficient multi-GPU training.
+:::
+
+:::{grid-item-card} GRPO Training
+Train models with Generalized Reward Policy Optimization for aligning with human preferences.
+:::
+
+:::{grid-item-card} Asynchronous RL
+Continuous rollout generation with non-blocking training for maximum throughput.
+:::
+
+:::{grid-item-card} Code Execution
+Safe, sandboxed code execution environments for RL on coding tasks (RLVR).
+:::
+
+:::{grid-item-card} Tool Integration
+Extensible environment system for agents that interact with tools and APIs.
+:::
+
+:::{grid-item-card} Custom Workflows
+Build your own components and compose them naturally with existing infrastructure.
+:::
+
+::::
+## Requirements at a Glance
+
+Before diving in, check out {doc}`getting_started` and ensure your system meets the requirements.
+
+## Writing RL Code
+
+With TorchForge, your RL logic looks like pseudocode:
+
+```python
+async def generate_episode(dataloader, policy, reward, replay_buffer):
+ # Sample a prompt
+ prompt, target = await dataloader.sample.route()
+
+ # Generate response
+ response = await policy.generate.route(prompt)
+
+ # Score the response
+ reward_value = await reward.evaluate_response.route(
+ prompt=prompt,
+ response=response.text,
+ target=target
+ )
+
+ # Store for training
+ await replay_buffer.add.route(
+ Episode(prompt_ids=response.prompt_ids,
+ response_ids=response.token_ids,
+ reward=reward_value)
+ )
+```
+
+No retry logic, no resource management, no synchronization code - just your algorithm.
+
+## Documentation Paths
+
+Choose your learning path:
+
+::::{grid} 1 1 2 2
+:gutter: 3
+
+:::{grid-item-card} 🚀 Getting Started
+:link: getting_started
+:link-type: doc
+
+Installation, prerequisites, verification, and your first training run.
+
+**Time to first run: ~15 minutes**
+:::
+
+:::{grid-item-card} 💻 Tutorials
+:link: tutorials
+:link-type: doc
+
+Step-by-step guides and practical examples for training with TorchForge.
+
+**For hands-on development**
+:::
+
+:::{grid-item-card} 📖 API Reference
+:link: api
+:link-type: doc
+
+Complete API documentation for customization and extension.
+
+**For deep integration**
+:::
+
+::::
+
+## Validation & Partnerships
+
+TorchForge has been validated in real-world deployments:
+
+- **Stanford Collaboration**: Integration with the Weaver weak verifier project, training models that hill-climb on challenging reasoning benchmarks (MATH, GPQA)
+- **CoreWeave**: Large-scale training on 512 H100 GPU clusters with smooth, efficient performance
+- **Scale**: Tested across hundreds of GPUs with continuous rollouts and asynchronous training
+
+## Community
+
+- **GitHub**: [meta-pytorch/forge](https://github.com/meta-pytorch/forge)
+- **Issues**: [Report bugs and request features](https://github.com/meta-pytorch/forge/issues)
+- **Contributing**: [CONTRIBUTING.md](https://github.com/meta-pytorch/forge/blob/main/CONTRIBUTING.md)
+- **Code of Conduct**: [CODE_OF_CONDUCT.md](https://github.com/meta-pytorch/forge/blob/main/CODE_OF_CONDUCT.md)
+
+```{tip}
+Before starting significant work, signal your intention in the issue tracker to coordinate with maintainers.
+```
* **Post-Training Focus**: Specializes in techniques
- like Supervised Fine-Tuning (SFT) and Generalized Reward Policy Optimization (GRPO)
+ like Supervised Fine-Tuning (SFT) and Group Relative Policy Optimization (GRPO)
* **PyTorch Integration**: Built natively on PyTorch with
dependencies on [PyTorch nightly](https://pytorch.org/get-started/locally/),
[Monarch](https://meta-pytorch.org/monarch), [vLLM](https://docs.vllm.ai/en/latest/),
@@ -18,19 +190,19 @@ Key Features
like Llama3 8B and Qwen3.1 7B
```{toctree}
-:maxdepth: 1
-:caption: Contents:
+:maxdepth: 2
+:caption: Documentation
getting_started
-concepts
-usage
tutorials
api
-faq
```
-## Indices and tables
+## Indices
+
+* {ref}`genindex` - Index of all documented objects
+* {ref}`modindex` - Python module index
+
+---
-* {ref}`genindex`
-* {ref}`modindex`
-* {ref}`search`
+**License**: BSD 3-Clause | **GitHub**: [meta-pytorch/forge](https://github.com/meta-pytorch/forge)
diff --git a/docs/source/metric_logging.md b/docs/source/metric_logging.md
new file mode 100644
index 000000000..e2f9fe27e
--- /dev/null
+++ b/docs/source/metric_logging.md
@@ -0,0 +1 @@
+```{include} ../../src/forge/observability/README.md
diff --git a/docs/source/tutorial_sources/README.txt b/docs/source/tutorial_sources/README.txt
new file mode 100644
index 000000000..0f59efa01
--- /dev/null
+++ b/docs/source/tutorial_sources/README.txt
@@ -0,0 +1,5 @@
+Tutorials
+=========
+
+This gallery contains tutorials and examples to help you get started with TorchForge.
+Each tutorial demonstrates specific features and use cases with practical examples.
diff --git a/docs/source/tutorial_sources/template_tutorial.py b/docs/source/tutorial_sources/template_tutorial.py
new file mode 100644
index 000000000..4018aa1b1
--- /dev/null
+++ b/docs/source/tutorial_sources/template_tutorial.py
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Template Tutorial
+=================
+
+**Author:** `FirstName LastName `_
+
+.. grid:: 2
+
+ .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+ :class-card: card-prerequisites
+
+ * Item 1
+ * Item 2
+ * Item 3
+
+ .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+ :class-card: card-prerequisites
+
+ * PyTorch v2.0.0
+ * GPU ???
+ * Other items 3
+
+
+To test your tutorial locally, you can do one of the following:
+
+* You can control specific files that generate the results by using
+ ``GALLERY_PATTERN`` environment variable. The GALLERY_PATTERN variable
+ respects regular expressions.
+ For example to run only ``neural_style_transfer_tutorial.py``,
+ use the following command:
+
+ .. code-block:: sh
+
+ GALLERY_PATTERN="neural_style_transfer_tutorial.py" make html
+
+ or
+
+ .. code-block:: sh
+
+ GALLERY_PATTERN="neural_style_transfer_tutorial.py" sphinx-build . _build
+
+* Make a copy of this repository and add only your
+ tutorial to the `beginner_source` directory removing all other tutorials.
+ Then run ``make html``.
+
+Verify that all outputs were generated correctly in the created HTML.
+"""
+
+#########################################################################
+# Overview
+# --------
+#
+# Describe Why is this topic important? Add Links to relevant research papers.
+#
+# This tutorial walks you through the process of....
+#
+# Steps
+# -----
+#
+# Example code (the output below is generated automatically):
+#
+import torch
+
+x = torch.rand(5, 3)
+print(x)
+
+######################################################################
+# (Optional) Additional Exercises
+# -------------------------------
+#
+# Add additional practice exercises for users to test their knowledge.
+# Example: `NLP from Scratch `__.
+#
+
+######################################################################
+# Conclusion
+# ----------
+#
+# Summarize the steps and concepts covered. Highlight key takeaways.
+#
+# Further Reading
+# ---------------
+#
+# * Link1
+# * Link2
diff --git a/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md
new file mode 100644
index 000000000..37314831c
--- /dev/null
+++ b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md
@@ -0,0 +1,410 @@
+# Part 1: RL Fundamentals - Using TorchForge Terminology
+
+## Core RL Components in TorchForge
+
+Let's start with a simple math tutoring example to understand RL concepts with the exact names TorchForge uses:
+
+### The Toy Example: Teaching Math
+
+```mermaid
+graph TD
+ subgraph Example["Math Tutoring RL Example"]
+ Dataset["Dataset: math problems"]
+ Policy["Policy: student AI"]
+ Reward["Reward Model:
+ scores answers"]
+ Reference["Reference Model:
+ baseline"]
+ ReplayBuffer["Replay Buffer: stores experiences"]
+ Trainer["Trainer: improves student"]
+ end
+
+ Dataset --> Policy
+ Policy --> Reward
+ Policy --> Reference
+ Reward --> ReplayBuffer
+ Reference --> ReplayBuffer
+ ReplayBuffer --> Trainer
+ Trainer --> Policy
+
+ style Policy fill:#4CAF50,stroke:#fff,stroke-width:2px
+ style Reward fill:#FF9800,stroke:#fff,stroke-width:2px
+ style Trainer fill:#E91E63,stroke:#fff,stroke-width:2px
+
+ linkStyle default stroke:#888,stroke-width:2px
+```
+
+### RL Components Defined (TorchForge Names)
+
+1. **Dataset**: Provides questions/prompts (like "What is 2+2?")
+2. **Policy**: The AI being trained (generates answers like "The answer is 4")
+3. **Reward Model**: Evaluates answer quality (gives scores like 0.95)
+4. **Reference Model**: Original policy copy (prevents drift from baseline)
+5. **Replay Buffer**: Stores experiences (question + answer + score)
+6. **Trainer**: Updates the policy weights based on experiences
+
+### The RL Learning Flow
+
+```python
+# CONCEPTUAL EXAMPLE - see apps/grpo/main.py for GRPO Code
+
+def conceptual_rl_step():
+ # 1. Get a math problem
+ question = dataset.sample() # "What is 2+2?"
+
+ # 2. Student generates answer
+ answer = policy.generate(question) # "The answer is 4"
+
+ # 3. Teacher grades it
+ score = reward_model.evaluate(question, answer) # 0.95
+
+ # 4. Compare to original student
+ baseline = reference_model.compute_logprobs(question, answer)
+
+ # 5. Store the experience
+ experience = Episode(question, answer, score, baseline)
+ replay_buffer.add(experience)
+
+ # 6. When enough experiences collected, improve student
+ batch = replay_buffer.sample(curr_policy_version=0)
+ if batch is not None:
+ trainer.train_step(batch) # Student gets better!
+
+# 🔄 See complete working example below with actual TorchForge service calls
+```
+
+## From Concepts to TorchForge Services
+
+Here's the key insight: **Each RL component becomes a TorchForge service**. The toy example above maps directly to TorchForge:
+
+```mermaid
+graph LR
+ subgraph Concepts["RL Concepts"]
+
+ C1["Dataset"]
+ C2["Policy"]
+ C3["Reward Model"]
+ C4["Reference Model"]
+ C5["Replay Buffer"]
+ C6["Trainer"]
+ end
+
+ subgraph Services["TorchForge Services (Real Classes)"]
+
+ S1["DatasetActor"]
+ S2["Generator"]
+ S3["RewardActor"]
+ S4["ReferenceModel"]
+ S5["ReplayBuffer"]
+ S6["TitanTrainer"]
+ end
+
+ C1 --> S1
+ C2 --> S2
+ C3 --> S3
+ C4 --> S4
+ C5 --> S5
+ C6 --> S6
+
+ style C2 fill:#4CAF50
+ style S2 fill:#4CAF50
+ style C3 fill:#FF9800
+ style S3 fill:#FF9800
+```
+
+### RL Step with TorchForge Services
+
+Let's look at the example from above again, but this time we would use the names from TorchForge:
+
+```python
+# Conceptual Example
+
+async def conceptual_forge_rl_step(services, step):
+ # 1. Get a math problem - Using actual DatasetActor API
+ sample = await services['dataloader'].sample.call_one()
+ question, target = sample["request"], sample["target"]
+
+ # 2. Student generates answer - Using actual Policy API
+ responses = await services['policy'].generate.route(prompt=question)
+ answer = responses[0].text
+
+ # 3. Teacher grades it - Using actual RewardActor API
+ score = await services['reward_actor'].evaluate_response.route(
+ prompt=question, response=answer, target=target
+ )
+
+ # 4. Compare to baseline - Using actual ReferenceModel API
+ # Note: ReferenceModel.forward requires input_ids, max_req_tokens, return_logprobs
+ ref_logprobs = await services['ref_model'].forward.route(
+ input_ids, max_req_tokens, return_logprobs=True
+ )
+
+ # 5. Store experience - Using actual Episode structure from apps/grpo/main.py
+ episode = create_episode_from_response(responses[0], score, ref_logprobs, step)
+ await services['replay_buffer'].add.call_one(episode)
+
+ # 6. Improve student - Using actual training pattern
+ batch = await services['replay_buffer'].sample.call_one(
+ curr_policy_version=step
+ )
+ if batch is not None:
+ inputs, targets = batch
+ loss = await services['trainer'].train_step.call(inputs, targets)
+ return loss
+```
+
+**Key difference**: Same RL logic, but each component is now a distributed, fault-tolerant, auto-scaling service.
+
+Did you realise-we are not worrying about any Infra code here! TorchForge Automagically handles the details behind the scenes and you can focus on writing your RL Algorithms!
+
+
+## Why This Matters: Traditional ML Infrastructure Fails
+
+### The Infrastructure Challenge
+
+Our simple RL loop above has complex requirements:
+
+#### Problem 1: Different Resource Needs
+
+| Component | Resource Needs | Scaling Strategy |
+|-----------|----------------|------------------|
+| **Policy** (Student AI) | Large GPU memory | Multiple replicas for throughput |
+| **Reward Heuristic** (Teacher) | Small compute | CPU or small GPU |
+| **Trainer** (Tutor) | Massive GPU compute | Distributed training |
+| **Dataset** (Question Bank) | CPU intensive I/O | High memory bandwidth |
+
+### Problem 2: Complex Interdependencies
+
+```mermaid
+graph LR
+ A["Policy: Student AI
+ 'What is 2+2?' →
+ 'The answer is 4'"]
+ B["Reward: Teacher
+ Scores answer: 0.95"]
+ C["Reference: Original Student
+ Provides baseline comparison"]
+ D["Replay Buffer: Notebook
+ Stores: question
+ + answer
+ + score"]
+ E["Trainer: Tutor
+ Improves student
+ using experiences"]
+
+ A --> B
+ A --> C
+ B --> D
+ C --> D
+ D --> E
+ E --> A
+
+ style A fill:#4CAF50
+ style B fill:#FF9800
+ style C fill:#2196F3
+ style D fill:#8BC34A
+ style E fill:#E91E63
+```
+
+Each step has different:
+- **Latency requirements**: Policy inference needs low latency (each episode waits), training can batch multiple episodes together
+- **Scaling patterns**: Need N policy replicas to keep trainer busy, plus different sharding strategies (tensor parallel for training vs replicated inference)
+- **Failure modes**: Any component failure cascades to halt the entire pipeline (TorchForge prevents this with automatic failover)
+- **Resource utilization**: GPUs for inference/training, CPUs for data processing
+
+### Problem 3: The Coordination Challenge
+
+Unlike supervised learning where you process independent batches, RL requires coordination:
+
+```python
+# While this does work, it creates bottlenecks and resource waste
+def naive_rl_step():
+ # Policy waits idle while reward model works
+ response = policy_model.generate(prompt) # GPU busy
+ reward = reward_model.evaluate(prompt, response) # Policy GPU idle
+
+ # Training waits for single episode
+ loss = compute_loss(response, reward) # Batch size = 1, inefficient
+
+ # Everything stops if any component fails
+ if policy_fails or reward_fails or trainer_fails:
+ entire_system_stops()
+```
+
+## Enter TorchForge: RL-Native Architecture
+
+TorchForge solves these problems by treating each RL component as an **independent, distributed unit** - some as fault-tolerant services (like Policy inference where failures are easy to handle), others as actors (like Trainers where recovery semantics differ)
+
+Let's see how core RL concepts map to TorchForge components (you'll notice a mix of `.route()` for services and `.call_one()` for actors - we cover when to use each in Part 2):
+
+**Quick API Reference:** (covered in detail in Part 2: Service Communication Patterns)
+- `.route()` - Send request to any healthy replica in a service (load balanced)
+- `.call_one()` - Send request to a single actor instance
+- `.fanout()` - Send request to ALL replicas in a service
+
+```python
+async def real_rl_training_step(services, step):
+ """Single RL step using verified TorchForge APIs"""
+
+ # 1. Environment interaction - Using actual DatasetActor API
+ sample = await services['dataloader'].sample.call_one()
+ prompt, target = sample["request"], sample["target"]
+
+ responses = await services['policy'].generate.route(prompt)
+
+ # 2. Reward computation - Using actual RewardActor API
+ score = await services['reward_actor'].evaluate_response.route(
+ prompt=prompt, response=responses[0].text, target=target
+ )
+
+ # 3. Get reference logprobs - Using actual ReferenceModel API
+ # Note: ReferenceModel requires full input_ids tensor, not just tokens
+ input_ids = torch.cat([responses[0].prompt_ids, responses[0].token_ids])
+ ref_logprobs = await services['ref_model'].forward.route(
+ input_ids.unsqueeze(0), max_req_tokens=512, return_logprobs=True
+ )
+
+ # 4. Experience storage - Using actual Episode pattern from GRPO
+ episode = create_episode_from_response(responses[0], score, ref_logprobs, step)
+ await services['replay_buffer'].add.call_one(episode)
+
+ # 5. Learning - Using actual trainer pattern
+ batch = await services['replay_buffer'].sample.call_one(
+ curr_policy_version=step
+ )
+ if batch is not None:
+ inputs, targets = batch # GRPO returns (inputs, targets) tuple
+ loss = await services['trainer'].train_step.call(inputs, targets)
+
+ # 6. Policy synchronization - Using actual weight update pattern
+ await services['trainer'].push_weights.call(step + 1)
+ await services['policy'].update_weights.fanout(step + 1)
+
+ return loss
+```
+
+**Key insight**: Each line of RL pseudocode becomes a service call. The complexity of distribution, scaling, and fault tolerance is hidden behind these simple interfaces.
+
+## What Makes This Powerful
+
+### Automatic Resource Management
+```python
+responses = await policy.generate.route(prompt=question)
+answer = responses[0].text # responses is list[Completion]
+```
+
+TorchForge handles behind the scenes:
+- Routing to least loaded replica
+- GPU memory management
+- Batch optimization
+- Failure recovery
+- Auto-scaling based on demand
+
+### Independent Scaling
+```python
+
+from forge.actors.generator import Generator as Policy
+from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.reference_model import ReferenceModel
+from forge.actors.trainer import TitanTrainer
+from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
+from forge.data.rewards import MathReward, ThinkingReward
+import asyncio
+import torch
+
+model = "Qwen/Qwen3-1.7B"
+group_size = 1
+
+(
+ dataloader,
+ policy,
+ trainer,
+ replay_buffer,
+ compute_advantages,
+ ref_model,
+ reward_actor,
+) = await asyncio.gather(
+ # Dataset actor (CPU)
+ DatasetActor.options(procs=1).as_actor(
+ path="openai/gsm8k",
+ revision="main",
+ data_split="train",
+ streaming=True,
+ model=model,
+ ),
+ # Policy service with GPU
+ Policy.options(procs=1, with_gpus=True, num_replicas=1).as_service(
+ engine_config={
+ "model": model,
+ "tensor_parallel_size": 1,
+ "pipeline_parallel_size": 1,
+ "enforce_eager": False
+ },
+ sampling_config={
+ "n": group_size,
+ "max_tokens": 16,
+ "temperature": 1.0,
+ "top_p": 1.0
+ }
+ ),
+ # Trainer actor with GPU
+ TitanTrainer.options(procs=1, with_gpus=True).as_actor(
+ # Trainer config would come from YAML in real usage
+ model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": f"hf://{model}"},
+ optimizer={"name": "AdamW", "lr": 1e-5},
+ training={"local_batch_size": 2, "seq_len": 2048}
+ ),
+ # Replay buffer (CPU)
+ ReplayBuffer.options(procs=1).as_actor(
+ batch_size=2,
+ max_policy_age=1,
+ dp_size=1
+ ),
+ # Advantage computation (CPU)
+ ComputeAdvantages.options(procs=1).as_actor(),
+ # Reference model with GPU
+ ReferenceModel.options(procs=1, with_gpus=True).as_actor(
+ model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": f"hf://{model}"},
+ training={"dtype": "bfloat16"}
+ ),
+ # Reward actor (CPU)
+ RewardActor.options(procs=1, num_replicas=1).as_service(
+ reward_functions=[MathReward(), ThinkingReward()]
+ )
+ )
+```
+
+**TorchForge Components: Services vs Actors**
+
+TorchForge has two types of distributed components:
+- **Services**: Multiple replicas with automatic load balancing (like Policy, RewardActor)
+- **Actors**: Single instances that handle their own internal distribution (like TitanTrainer, ReplayBuffer)
+
+We cover this distinction in detail in Part 2, but for now this explains the scaling patterns:
+- Policy service: num_replicas=8 for high inference demand
+- RewardActor service: num_replicas=16 for parallel evaluation
+- TitanTrainer actor: Single instance with internal distributed training
+
+
+### Fault Tolerance
+```python
+# If a policy replica fails:
+responses = await policy.generate.route(prompt=question)
+answer = responses[0].text
+# -> TorchForge automatically routes to healthy replica
+# -> Failed replica respawns in background
+# -> No impact on training loop
+
+# If reward service fails:
+score = await reward_actor.evaluate_response.route(
+ prompt=question, response=answer, target=target
+)
+```
+
+- Retries on different replica automatically
+- Graceful degradation if all replicas fail
+- System continues (may need application-level handling)
+
+This is fundamentally different from monolithic RL implementations where any component failure stops everything!
+
+In the next Section, we will go a layer deeper and learn how ForgeServices work. Continue to [Part 2 here](./2_Forge_Internals.md)
diff --git a/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md
new file mode 100644
index 000000000..cc6cfeda3
--- /dev/null
+++ b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md
@@ -0,0 +1,688 @@
+# Part 2: Peeling Back the Abstraction - What Are Services?
+
+We highly recommend reading [Part 1](./1_RL_and_Forge_Fundamentals) before this, it explains RL Concepts and how they land in TorchForge.
+
+Now that you see the power of the service abstraction, let's understand what's actually happening under the hood, Grab your chai!
+
+## Service Anatomy: Beyond the Interface
+
+When you call `await policy_service.generate(question)`, here's what actually happens:
+
+(Don't worry, we will understand Services right in the next section!)
+
+```mermaid
+graph TD
+ Call["Your Code:
+ await policy_service
+ .generate.route"]
+
+ subgraph ServiceLayer["Service Layer"]
+ Proxy["Service Proxy:
+ Load balancing
+ Health checking"]
+ LB["Load Balancer:
+ Replica selection
+ Circuit breaker"]
+ end
+
+ subgraph Replicas["Replica Management"]
+ R1["Replica 1:
+ GPU 0, Healthy"]
+ R2["Replica 2:
+ GPU 1, Overloaded"]
+ R3["Replica 3:
+ GPU 2, Failed"]
+ R4["Replica 4:
+ GPU 3, Healthy"]
+ end
+
+ subgraph Compute["Actual Computation"]
+ Actor["Policy Actor:
+ vLLM engine,
+ Model weights,
+ KV cache"]
+ end
+
+ Call --> Proxy
+ Proxy --> LB
+ LB --> R1
+ LB -.-> R2
+ LB -.-> R3
+ LB --> R4
+ R1 --> Actor
+ R4 --> Actor
+
+ style Call fill:#4CAF50
+ style LB fill:#FF9800
+ style R3 fill:#F44336
+ style Actor fill:#9C27B0
+```
+
+## Service Components Deep Dive
+
+### 1. Real Service Configuration
+
+Here's the actual ServiceConfig from TorchForge source code:
+
+```python
+# Configuration pattern from apps/grpo/main.py:
+Policy.options(
+ procs=1, # Processes per replica
+ num_replicas=4, # Number of replicas
+ with_gpus=True # Allocate GPUs
+ # Other available options:
+ # hosts=None # the number of remote hosts used per replica
+)
+```
+
+### 2. Real Service Creation
+
+Services are created using the `.options().as_service()` pattern from the actual GRPO implementation:
+
+The service creation automatically handles:
+- Spawning actor replicas across processes/GPUs
+- Load balancing with .route() method for services
+- Health monitoring and failure recovery
+- Message routing and serialization
+
+```python
+from forge.actors.generator import Generator as Policy
+
+model = "Qwen/Qwen3-1.7B"
+
+policy = await Policy.options(
+ procs=1,
+ with_gpus=True,
+ num_replicas=1
+).as_service(
+ engine_config={
+ "model": model,
+ "tensor_parallel_size": 1,
+ "pipeline_parallel_size": 1,
+ "enforce_eager": False
+ },
+ sampling_config={
+ "n": 1,
+ "max_tokens": 16,
+ "temperature": 1.0,
+ "top_p": 1.0
+ }
+)
+
+prompt = "What is 3 + 5?"
+responses = await policy.generate.route(prompt)
+print(f"Response: {responses[0].text}")
+
+# Cleanup when done
+await policy.shutdown()
+```
+
+### 3. How Services Actually Work
+
+TorchForge services are implemented as ServiceActors that manage collections of your ForgeActor replicas:
+
+When you call `.as_service()`, TorchForge creates a `ServiceInterface` that manages N replicas of your `ForgeActor` class and gives you methods like `.route()`, `.fanout()`, etc.
+
+```python
+# Your code sees this simple interface:
+responses = await policy.generate.route(prompt=prompt)
+# But TorchForge handles all the complexity of replica management, load balancing, and fault tolerance
+```
+
+## Communication Patterns: Quick Reference
+
+**API Summary:**
+- `.route()` - Send request to any healthy replica in a service (load balanced)
+- `.call_one()` - Send request to a single actor instance
+- `.fanout()` - Send request to ALL replicas in a service
+
+```mermaid
+graph LR
+ subgraph Request["Your Request"]
+ Code["await service
+ .method.ADVERB()"]
+ end
+
+ subgraph Patterns["Communication Patterns"]
+ Route[".route()
+ → One healthy replica"]
+ CallOne[".call_one()
+ → Single actor"]
+ Fanout[".fanout()
+ → ALL replicas"]
+ end
+
+ subgraph Replicas["Replicas/Actors"]
+ R1["Replica 1"]
+ R2["Replica 2"]
+ R3["Replica 3"]
+ A1["Actor"]
+ end
+
+ Code --> Route
+ Code --> CallOne
+ Code --> Fanout
+
+ Route --> R2
+ CallOne --> A1
+ Fanout --> R1
+ Fanout --> R2
+ Fanout --> R3
+
+ style Route fill:#4CAF50
+ style CallOne fill:#FF9800
+ style Fanout fill:#9C27B0
+```
+
+## Deep Dive: Service Communication Patterns
+
+These communication patterns (\"adverbs\") determine how your service calls are routed to replicas. Understanding when to use each pattern is key to effective TorchForge usage.
+
+### 1. `.route()` - Load Balanced Single Replica
+
+**When to use**: Normal request routing where any replica can handle the request.
+
+```python
+responses = await policy.generate.route(prompt=question)
+answer = responses[0].text # Extract text from Completion object
+```
+
+Behind the scenes:
+1. Health check eliminates failed replicas
+2. Load balancer picks replica (currently round robin, configurable balancers coming soon)
+3. Request routes to that specific replica
+4. Automatic retry on different replica if failure
+
+**Performance characteristics**:
+- **Latency**: Lowest (single network hop)
+- **Throughput**: Limited by single replica capacity
+- **Fault tolerance**: Automatic failover to other replicas
+
+**Critical insight**: `.route()` is your default choice for stateless operations in TorchForge services.
+
+### 2. `.fanout()` - Broadcast with Results Collection
+
+**When to use**: You need responses from ALL replicas.
+
+```python
+# Get version from all policy replicas
+current_versions = await policy.get_version.fanout()
+# Returns: [version_replica_1, version_replica_2, ...]
+
+# Update weights on all replicas
+await policy.update_weights.fanout(new_policy_version)
+# Broadcasts to all replicas simultaneously
+```
+
+**Performance characteristics**:
+- **Latency**: Slowest replica determines total latency
+- **Throughput**: Network bandwidth × number of replicas
+- **Fault tolerance**: Fails if ANY replica fails (unless configured otherwise)
+
+**Critical gotcha**: Don't use `.fanout()` for high-frequency operations - it contacts all replicas.
+
+### 3. Streaming Operations - Custom Implementation Pattern
+
+**When to use**: You want to process results as they arrive, not wait for all.
+
+```python
+# CONCEPTUAL - Streaming requires custom implementation in your training loop
+# The basic ReplayBuffer doesn't have built-in streaming methods
+# Pattern from apps/grpo/main.py continuous training:
+
+while training:
+ # This is the real API call pattern
+ batch = await replay_buffer.sample.call_one(curr_policy_version=step)
+ if batch is not None:
+ # Process batch immediately
+ loss = await trainer.train_step.call_one(batch)
+ print(f"Training loss: {loss}")
+ else:
+ await asyncio.sleep(0.1) # Wait for more data
+```
+
+**Performance characteristics**:
+- **Latency**: Process first result immediately
+- **Throughput**: Non-blocking async operations (much higher than waiting for full batches)
+- **Fault tolerance**: Continues if some replicas fail
+
+**Critical insight**: This is essential for high-throughput RL where you can't wait for batches.
+
+### 3. Service Sessions for Stateful Operations
+
+**When to use**: When you need multiple calls to hit the same replica (like KV cache preservation).
+
+**What are sticky sessions?** A session ensures all your service calls within the `async with` block go to the same replica, instead of being load-balanced across different replicas.
+
+```python
+# This Counter example demonstrates the difference between regular routing and sessions
+
+from forge.controller import ForgeActor
+from monarch.actor import endpoint
+
+class ForgeCounter(ForgeActor):
+ def __init__(self, initial_value: int):
+ self.value = initial_value
+
+ @endpoint
+ def increment(self) -> int:
+ self.value += 1
+ return self.value
+
+ @endpoint
+ def get_value(self) -> int:
+ return self.value
+
+ @endpoint
+ async def reset(self):
+ self.value = 0
+
+counter_service = await ForgeCounter.options(
+ procs=1, num_replicas=4
+).as_service(initial_value=0)
+
+# WITHOUT SESSIONS: Each .route() call goes to a different replica
+await counter_service.increment.route() # Might go to replica 2
+await counter_service.increment.route() # Might go to replica 1
+await counter_service.increment.route() # Might go to replica 3
+
+results = await counter_service.increment.fanout() # Get from all replicas
+print(f"All replica values: {results}")
+# Output: All replica values: [1, 2, 1, 1] - Each replica has different state!
+```
+
+The problem: each `.route()` call can go to different replicas, creating inconsistent state.
+
+```python
+# WITH SESSIONS: All calls go to the SAME replica
+print("\nUsing sticky sessions:")
+async with counter_service.session(): # Creates a session that picks one replica
+ await counter_service.reset.route() # Uses .route() within session
+ print(await counter_service.increment.route()) # 1
+ print(await counter_service.increment.route()) # 2
+ print(await counter_service.increment.route()) # 3
+
+ final_value = await counter_service.get_value.route()
+ print(f"Final value on this replica: {final_value}") # 3
+
+# Output:
+# Using sticky sessions:
+# 1
+# 2
+# 3
+# Final value on this replica: 3
+
+# Same pattern works with Policy for multi-turn conversations:
+# async with policy.session():
+# response1 = await policy.generate.route(turn1)
+# full_prompt = turn1 + response1[0].text + turn2
+# response2 = await policy.generate.route(full_prompt)
+# # Both calls hit same replica, preserving KV cache
+
+# Cleanup
+await counter_service.shutdown()
+```
+
+**Performance impact**: Critical for maintaining KV cache in multi-turn conversations.
+
+## Deep Dive: State Management Reality
+
+The most complex challenge in distributed RL is maintaining state consistency while maximizing performance.
+
+### The KV Cache Problem
+
+**The challenge**: Policy inference is much faster with KV cache, but cache is tied to specific conversation history.
+
+```python
+# This breaks KV cache optimization:
+async def naive_multi_turn():
+ # Each call might go to different replica = cache miss
+ response1 = await policy_service.generate.choose(question1)
+ response2 = await policy_service.generate.choose(question1 + response1) # Cache miss!
+ response3 = await policy_service.generate.choose(conversation_so_far) # Cache miss!
+```
+
+**The solution**: Sticky sessions ensure all calls go to same replica.
+
+```python
+async def optimized_multi_turn():
+ async with policy.session():
+ # All calls guaranteed to hit same replica = cache hits
+ response1 = await policy.generate.route(prompt=question1)
+ full_prompt = question1 + response1[0].text
+ response2 = await policy.generate.route(prompt=full_prompt) # Cache hit!
+ conversation = full_prompt + response2[0].text
+ response3 = await policy.generate.route(prompt=conversation) # Cache hit!
+
+ # Session ends, replica can be garbage collected or reused
+```
+
+**Performance impact**: Maintaining KV cache across turns avoids recomputing previous tokens.
+
+### Replay Buffer Consistency
+
+**The challenge**: Multiple trainers and experience collectors reading/writing concurrently.
+
+**Real TorchForge approach**: The ReplayBuffer actor handles concurrency internally:
+
+```python
+# TorchForge ReplayBuffer endpoints (verified from source code)
+# Add episodes (thread-safe by actor model)
+await replay_buffer.add.call_one(episode) # .choose() would work too, but .call_one() clarifies it's a singleton actor not ActorMesh
+
+# Sample batches for training
+batch = await replay_buffer.sample.call_one(
+ curr_policy_version=step_number,
+ batch_size=None # Optional parameter, uses default from config
+)
+
+# Additional methods available:
+# await replay_buffer.clear.call_one() # Clear buffer
+# await replay_buffer.evict.call_one(curr_policy_version) # Remove old episodes
+# state = await replay_buffer.state_dict.call_one() # Get state for checkpointing
+```
+
+**Critical insight**: The actor model provides natural thread safety - each actor processes messages sequentially.
+
+### Weight Synchronization Strategy
+
+**The challenge**: Trainer updates policy weights, but policy service needs those weights.
+
+```python
+# TorchForge weight synchronization pattern from apps/grpo/main.py
+async def real_weight_sync(trainer, policy, step):
+ # Trainer pushes weights to TorchStore with version number
+ await trainer.push_weights.call_one(policy_version=step + 1)
+
+ # Policy service updates to new version from TorchStore
+ # Use .fanout() to update ALL policy replicas
+ await policy.update_weights.fanout(policy_version=step + 1)
+
+# Check current policy version
+current_version = await policy.get_version.route()
+print(f"Current policy version: {current_version}")
+```
+
+## Deep Dive: Asynchronous Coordination Patterns
+
+**The real challenge**: Different services run at different speeds, but TorchForge's service abstraction handles the coordination complexity.
+
+### The TorchForge Approach: Let Services Handle Coordination
+
+Instead of manual coordination, TorchForge services handle speed mismatches automatically:
+
+```python
+from apps.grpo.main import Episode, Group
+
+async def simple_rl_step():
+
+ # ===== Generate a rollout =====
+ sample = await dataloader.sample.call_one() # DatasetActor is an actor, not service
+ prompt, target = sample["request"], sample["target"] # Correct field names
+
+ print(f"Prompt: {prompt}")
+ print(f"Target: {target}")
+
+ actions = await policy.generate.route(prompt=prompt) # Policy is a service
+ print(f"Policy response: {actions[0].text}")
+
+ # Create input tensor for reference model (requires full context)
+ input_ids = torch.cat([actions[0].prompt_ids, actions[0].token_ids])
+ ref_logprobs = await ref_model.forward.route(
+ input_ids.unsqueeze(0), max_req_tokens=512, return_logprobs=True
+ )
+ reward = await reward_actor.evaluate_response.route( # RewardActor is a service
+ prompt=prompt,
+ response=actions[0].text,
+ target=target
+ )
+ print(f"Reward: {reward}")
+
+ # Create episode using actual GRPO Episode structure
+ episode = Episode(
+ episode_id="0",
+ request=prompt,
+ policy_version=0,
+ pad_id=tokenizer.pad_token_id,
+ request_len=512,
+ response_len=512,
+ target=target
+ )
+
+ # Add response data
+ episode.response = actions[0].text
+ episode.request_tokens = actions[0].prompt_ids.tolist()
+ episode.response_tokens = actions[0].token_ids.tolist()
+ episode.ref_logprobs = ref_logprobs[0] # Extract from batch dimension
+ episode.reward = reward
+
+ # Compute advantages using actual ComputeAdvantages actor
+ group = Group.new_group(0, 1, prompt, 0, tokenizer.pad_token_id, 512, 512, target)
+ group.episodes[0] = episode
+ advantages = await compute_advantages.compute.call_one(group) # ComputeAdvantages is an actor
+ episode.advantage = advantages[0]
+ print(f"Advantage: {advantages[0]}")
+ await replay_buffer.add.call_one(episode) # ReplayBuffer is an actor
+ print("Episode stored in replay buffer")
+
+ # ===== Train on the batch =====
+ batch = await replay_buffer.sample.call_one(curr_policy_version=0)
+ if batch is not None:
+ print("Training on batch...")
+ inputs, targets = batch # GRPO returns (inputs, targets) tuple
+ loss = await trainer.train_step.call(inputs, targets) # TitanTrainer is an actor
+ print(f"Training loss: {loss}")
+ return loss
+ else:
+ print("Not enough data in buffer yet")
+ return None
+
+# Note: This simplified example assumes tokenizer and services are already initialized
+for step in range(10):
+ print(f"\n--- RL Step {step + 1} ---")
+ loss = await simple_rl_step()
+ if loss:
+ print(f"Step {step + 1} complete, loss: {loss:.4f}")
+ else:
+ print(f"Step {step + 1} complete, building buffer...")
+```
+
+### Handling Speed Mismatches with Service Scaling
+
+**The insight**: Scale services independently based on their bottlenecks.
+
+```python
+# Scale fast services with more replicas
+policy = await Policy.options(
+ procs=1, num_replicas=8, with_gpus=True # Many replicas for high throughput
+).as_service(
+ engine_config={"model": model_name, "tensor_parallel_size": 1}
+)
+
+# Reward evaluation might be CPU-bound
+reward_actor = await RewardActor.options(
+ procs=1, num_replicas=16, with_gpus=False # More CPU replicas
+).as_service(
+ reward_functions=[MathReward()]
+)
+
+# Training needs fewer but more powerful replicas
+trainer = await TitanTrainer.options(
+ procs=1, with_gpus=True # Fewer but GPU-heavy
+).as_actor( # Trainer typically uses .as_actor() not .as_service()
+ model={"name": "qwen3", "flavor": "1.7B"},
+ optimizer={"name": "AdamW", "lr": 1e-5}
+)
+```
+
+## Service Implementation Example
+
+Let's see how a reward service is actually implemented:
+
+```python
+# Exact RewardActor from apps/grpo/main.py
+
+from forge.controller import ForgeActor
+from monarch.actor import endpoint
+from forge.data.rewards import MathReward, ThinkingReward
+
+# class definition from apps/grpo/main.py
+class RewardActor(ForgeActor):
+ def __init__(self, reward_functions: list):
+ self.reward_functions = reward_functions
+
+ @endpoint
+ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+ """Evaluate response quality using multiple reward functions"""
+ total_reward = 0.0
+
+ for reward_fn in self.reward_functions:
+ # Each reward function contributes to total score
+ reward = reward_fn(prompt, response, target)
+ total_reward += reward
+
+ # Return average reward across all functions
+ return total_reward / len(self.reward_functions) if self.reward_functions else 0.0
+
+reward_actor = await RewardActor.options(
+ procs=1, num_replicas=1
+).as_service(
+ reward_functions=[MathReward(), ThinkingReward()]
+)
+
+prompt = "What is 15% of 240?"
+response = "15% of 240 is 36"
+target = "36"
+
+score = await reward_actor.evaluate_response.route(
+ prompt=prompt,
+ response=response,
+ target=target
+)
+print(f"Reward score: {score}") # Usually around 1.0 for correct math answers
+
+# For production scaling - increase num_replicas for parallel evaluation:
+# RewardActor.options(procs=1, num_replicas=16) # 16 parallel evaluators
+
+# Cleanup when done
+await reward_actor.shutdown()
+```
+
+## Service Orchestration: The Training Loop
+
+Now let's see how services coordinate in a real training loop:
+
+```python
+# This is the REAL way production RL systems are built with TorchForge
+
+import asyncio
+import torch
+from forge.actors.generator import Generator as Policy
+from forge.actors.reference_model import ReferenceModel
+from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import TitanTrainer
+from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
+from forge.data.rewards import MathReward, ThinkingReward
+
+# Service creation pattern from apps/grpo/main.py lines 322-344
+print("Initializing all services...")
+(
+ dataloader,
+ policy,
+ trainer,
+ replay_buffer,
+ compute_advantages,
+ ref_model,
+ reward_actor,
+) = await asyncio.gather(
+ DatasetActor.options(procs=1).as_actor(
+ path="openai/gsm8k", revision="main", data_split="train",
+ streaming=True, model="Qwen/Qwen3-1.7B"
+ ),
+ Policy.options(procs=1, with_gpus=True, num_replicas=1).as_service(
+ engine_config={"model": "Qwen/Qwen3-1.7B", "tensor_parallel_size": 1},
+ sampling_config={"n": 1, "max_tokens": 512}
+ ),
+ TitanTrainer.options(procs=1, with_gpus=True).as_actor(
+ model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": "hf://Qwen/Qwen3-1.7B"},
+ optimizer={"name": "AdamW", "lr": 1e-5},
+ training={"local_batch_size": 2, "seq_len": 2048}
+ ),
+ ReplayBuffer.options(procs=1).as_actor(
+ batch_size=2, max_policy_age=1, dp_size=1
+ ),
+ ComputeAdvantages.options(procs=1).as_actor(),
+ ReferenceModel.options(procs=1, with_gpus=True).as_actor(
+ model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": "hf://Qwen/Qwen3-1.7B"}
+ ),
+ RewardActor.options(procs=1, num_replicas=1).as_service(
+ reward_functions=[MathReward(), ThinkingReward()]
+ ),
+)
+
+print("All services initialized successfully!")
+
+async def production_training_loop():
+ """Real training loop pattern from apps/grpo/main.py"""
+ step = 0
+
+ while True:
+ # Data generation
+ sample = await dataloader.sample.call_one()
+
+ # Policy generation service call
+ responses = await policy.generate.route(sample["request"]) # Correct field name
+
+ # Reference computation service call (requires full input tensor)
+ input_ids = torch.cat([responses[0].prompt_ids, responses[0].token_ids])
+ ref_logprobs = await ref_model.forward.route(
+ input_ids.unsqueeze(0), max_req_tokens=512, return_logprobs=True
+ )
+
+ # Reward evaluation service call
+ reward = await reward_actor.evaluate_response.route(
+ prompt=sample["question"],
+ response=responses[0].text,
+ target=sample["answer"]
+ )
+
+ # Experience storage (using actual Episode structure)
+ episode = create_episode_from_grpo_data(sample, responses[0], reward, ref_logprobs[0], step)
+ await replay_buffer.add.call_one(episode)
+
+ # Training when ready
+ batch = await replay_buffer.sample.call_one(curr_policy_version=step)
+ if batch is not None:
+ inputs, targets = batch # GRPO returns (inputs, targets) tuple
+ loss = await trainer.train_step.call(inputs, targets)
+
+ # Weight synchronization pattern
+ await trainer.push_weights.call(step + 1)
+ await policy.update_weights.fanout(step + 1) # Fanout to all replicas
+
+ print(f"Step {step}, Loss: {loss:.4f}")
+ step += 1
+
+print("Shutting down services...")
+await asyncio.gather(
+ DatasetActor.shutdown(dataloader),
+ policy.shutdown(),
+ TitanTrainer.shutdown(trainer),
+ ReplayBuffer.shutdown(replay_buffer),
+ ComputeAdvantages.shutdown(compute_advantages),
+ ReferenceModel.shutdown(ref_model),
+ reward_actor.shutdown(),
+)
+print("All services shut down successfully!")
+```
+
+**Key observations:**
+1. **Parallelism**: Independent operations run concurrently
+2. **Load balancing**: Each `.route()` call automatically selects optimal replica
+3. **Fault tolerance**: Failures automatically retry on different replicas
+4. **Resource efficiency**: CPU and GPU services scale independently
+5. **Coordination**: Services coordinate through shared state (replay buffer, weight versions)
+
+This is the power of the service abstraction - complex distributed coordination looks like simple async Python code.
+
+In the next part we will learn about [Monarch internals](./3_Monarch_101.md)
diff --git a/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md b/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md
new file mode 100644
index 000000000..38faa72ee
--- /dev/null
+++ b/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md
@@ -0,0 +1,404 @@
+# Part 3: The TorchForge-Monarch Connection
+
+This is part 3 of our series, in the previous sections: we learned Part 1: [RL Concepts and how they map to TorchForge](./1_RL_and_Forge_Fundamentals), Part 2: [TorchForge Internals](./2_Forge_Internals).
+
+Now let's peel back the layers. TorchForge services are built on top of **Monarch**, PyTorch's distributed actor framework. Understanding this connection is crucial for optimization and debugging.
+
+## The Complete Hierarchy: Service to Silicon
+
+```mermaid
+graph TD
+ subgraph YourCode["(1) Your RL Code"]
+ Call["await policy_service
+ .generate.route
+ ('What is 2+2?')"]
+ end
+
+ subgraph ForgeServices["(2) TorchForge Service Layer"]
+ ServiceInterface["ServiceInterface:
+ Routes requests
+ Load balancing
+ Health checks"]
+ ServiceActor["ServiceActor:
+ Manages replicas
+ Monitors health
+ Coordinates failures"]
+ end
+
+ subgraph MonarchLayer["(3) Monarch Actor Layer"]
+ ActorMesh["ActorMesh Policy Actor:
+ 4 instances
+ Different GPUs
+ Message passing"]
+ ProcMesh["ProcMesh:
+ 4 processes
+ GPU topology 0,1,2,3
+ Network interconnect"]
+ end
+
+ subgraph Hardware["(4) Physical Hardware"]
+ GPU0["GPU 0:
+ Policy Actor #1
+ vLLM Engine
+ Model Weights"]
+ GPU1["GPU 1:
+ Policy Actor #2
+ vLLM Engine
+ Model Weights"]
+ GPU2["GPU 2:
+ Policy Actor #3
+ vLLM Engine
+ Model Weights"]
+ GPU3["GPU 3:
+ Policy Actor #4
+ vLLM Engine
+ Model Weights"]
+ end
+
+ Call --> ServiceInterface
+ ServiceInterface --> ServiceActor
+ ServiceActor --> ActorMesh
+ ActorMesh --> ProcMesh
+ ProcMesh --> GPU0
+ ProcMesh --> GPU1
+ ProcMesh --> GPU2
+ ProcMesh --> GPU3
+
+ style Call fill:#4CAF50
+ style ServiceActor fill:#FF9800
+ style ActorMesh fill:#9C27B0
+ style ProcMesh fill:#2196F3
+```
+
+## Deep Dive: ProcMesh - The Foundation
+
+**ProcMesh** is Monarch's core abstraction for organizing processes across hardware. Think of it as a multi-dimensional grid that maps directly to your cluster topology.
+
+### Single Host ProcMesh
+
+**Key insight**: ProcMesh creates one process per GPU, automatically handling the process-to-hardware mapping.
+
+```python
+# This simple call:
+procs = this_host().spawn_procs(per_host={"gpus": 8})
+
+# Creates:
+# Process 0 → GPU 0
+# Process 1 → GPU 1
+# Process 2 → GPU 2
+# Process 3 → GPU 3
+# Process 4 → GPU 4
+# Process 5 → GPU 5
+# Process 6 → GPU 6
+# Process 7 → GPU 7
+```
+
+The beauty: you don't manage individual processes or GPU assignments - ProcMesh handles the topology for you.
+
+### Multi-Host ProcMesh
+
+**Key insight**: ProcMesh seamlessly scales across multiple hosts with continuous process numbering.
+
+```python
+# Same simple API works across hosts:
+cluster_procs = spawn_cluster_procs(
+ hosts=["host1", "host2", "host3"],
+ per_host={"gpus": 4}
+)
+
+# Automatically creates:
+# Host 1: Processes 0-3 → GPUs 0-3
+# Host 2: Processes 4-7 → GPUs 0-3
+# Host 3: Processes 8-11 → GPUs 0-3
+
+# Your code stays the same whether it's 1 host or 100 hosts
+actors = cluster_procs.spawn("my_actor", MyActor)
+```
+
+**The power**: Scale from single host to cluster without changing your actor code - ProcMesh handles all the complexity.
+
+```python
+# This shows the underlying actor system that powers TorchForge services
+# NOTE: This is for educational purposes - use ForgeActor and .as_service() in real TorchForge apps!
+
+from monarch.actor import Actor, endpoint, this_proc, Future
+from monarch.actor import ProcMesh, this_host
+import asyncio
+
+# STEP 1: Define a basic actor
+class Counter(Actor):
+ def __init__(self, initial_value: int):
+ self.value = initial_value
+
+ @endpoint
+ def increment(self) -> None:
+ self.value += 1
+
+ @endpoint
+ def get_value(self) -> int:
+ return self.value
+
+# STEP 2: Single actor in local process
+counter: Counter = this_proc().spawn("counter", Counter, initial_value=0)
+
+# STEP 3: Send messages
+fut: Future[int] = counter.get_value.call_one()
+value = await fut
+print(f"Counter value: {value}") # 0
+
+# STEP 4: Multiple actors across processes
+procs: ProcMesh = this_host().spawn_procs(per_host={"gpus": 8})
+counters: Counter = procs.spawn("counters", Counter, 0)
+
+# STEP 5: Broadcast to all actors
+await counters.increment.call()
+
+# STEP 6: Different message patterns
+# call_one() - single actor
+value = await counters.get_value.call_one()
+print(f"One counter: {value}") # Output: One counter: 1
+
+# choose() - random single actor (actors only, not services)
+value = await counters.get_value.choose()
+print(f"Random counter: {value}") # Output: Random counter: 1
+
+# call() - all actors, collect results
+values = await counters.get_value.call()
+print(f"All counters: {values}") # Output: All counters: [1, 1, 1, 1, 1, 1, 1, 1]
+
+# broadcast() - fire and forget
+await counters.increment.broadcast() # No return value - just sends to all actors
+
+# Cleanup
+await procs.stop()
+
+# Remember: This raw Monarch code is for understanding how TorchForge works internally.
+# In your TorchForge applications, use ForgeActor, .as_service(), .as_actor() instead!
+```
+
+## Actor Meshes: Your Code Running Distributed
+
+**ActorMesh** is created when you spawn actors across a ProcMesh. Key points:
+
+- **One actor instance per process**: `mesh.spawn("policy", Policy)` creates one Policy Actor in each process
+- **Same constructor arguments**: All instances get the same initialization parameters
+- **Independent state**: Each actor instance maintains its own state and memory
+- **Message routing**: You can send messages to one actor or all actors using different methods
+
+```python
+# Simple example:
+procs = spawn_procs(per_host={"gpus": 4}) # 4 processes
+policy_actors = procs.spawn("policy", Policy, model="Qwen/Qwen3-7B")
+
+# Now you have 4 Policy Actor instances, one per GPU
+# All initialized with the same model parameter
+```
+
+## How TorchForge Services Use Monarch
+
+Now the key insight: **TorchForge services are ServiceActors that manage ActorMeshes of your ForgeActor replicas**.
+
+### The Service Creation Process
+
+```mermaid
+graph TD
+ subgraph ServiceCreation["Service Creation Process"]
+ Call["await Policy
+ .options(
+ num_replicas=4,
+ procs=1)
+ .as_service(
+ model='Qwen')"]
+
+ ServiceActor["ServiceActor:
+ Manages 4 replicas
+ Health checks
+ Routes calls"]
+
+ subgraph Replicas["4 Independent Replicas"]
+ subgraph R0["Replica 0"]
+ PM0["ProcMesh:
+ 1 process
+ GPU 0"]
+ AM0["ActorMesh
+ 1 Policy Actor"]
+ end
+
+ subgraph R1["Replica 1"]
+ PM1["ProcMesh:
+ 1 process
+ GPU 1"]
+ AM1["ActorMesh
+ 1 Policy Actor"]
+ end
+
+ subgraph R2["Replica 2"]
+ PM2["ProcMesh:
+ 1 process
+ GPU 2"]
+ AM2["ActorMesh
+ 1 Policy Actor"]
+ end
+
+ subgraph R3["Replica 3"]
+ PM3["ProcMesh:
+ 1 process
+ GPU 3"]
+ AM3["ActorMesh
+ 1 Policy Actor"]
+ end
+ end
+
+ Call --> ServiceActor
+ ServiceActor --> R0
+ ServiceActor --> R1
+ ServiceActor --> R2
+ ServiceActor --> R3
+ PM0 --> AM0
+ PM1 --> AM1
+ PM2 --> AM2
+ PM3 --> AM3
+ end
+
+ style ServiceActor fill:#FF9800
+ style AM0 fill:#4CAF50
+ style AM1 fill:#4CAF50
+ style AM2 fill:#4CAF50
+ style AM3 fill:#4CAF50
+```
+
+### Service Call to Actor Execution
+
+```mermaid
+:align: center
+graph TD
+ subgraph CallFlow["Complete Call Flow"]
+ UserCall["await policy_service
+ .generate.route
+ ('What is 2+2?')"]
+
+ ServiceInterface["ServiceInterface:
+ Receives .route() call
+ Routes to ServiceActor"]
+
+ ServiceActor["ServiceActor:
+ Selects healthy replica
+ Load balancing
+ Failure handling"]
+
+ SelectedReplica["Selected Replica #2:
+ ProcMesh 1 process
+ ActorMesh 1 Policy Actor"]
+
+ PolicyActor["Policy Actor Instance:
+ Loads model
+ Runs vLLM inference"]
+
+ GPU["GPU 2:
+ vLLM engine
+ Model weights
+ KV cache
+ CUDA kernels"]
+
+ UserCall --> ServiceInterface
+ ServiceInterface --> ServiceActor
+ ServiceActor --> SelectedReplica
+ SelectedReplica --> PolicyActor
+ PolicyActor --> GPU
+
+ GPU -.->|"Response"| PolicyActor
+ PolicyActor -.->|"Response"| SelectedReplica
+ SelectedReplica -.->|"Response"| ServiceActor
+ ServiceActor -.->|"Response"| ServiceInterface
+ ServiceInterface -.->|"'The answer is 4'"| UserCall
+ end
+
+ style UserCall fill:#4CAF50
+ style ServiceActor fill:#FF9800
+ style PolicyActor fill:#9C27B0
+ style GPU fill:#FF5722
+```
+
+## Multiple Services Sharing Infrastructure
+
+In real RL systems, you have multiple services that can share or use separate ProcMeshes:
+
+```mermaid
+graph TD
+ subgraph Cluster["RL Training Cluster"]
+ subgraph Services["TorchForge Services"]
+ PS["Policy Service - 4 GPU replicas"]
+ TS["Trainer Service - 2 GPU replicas"]
+ RS["Reward Service - 4 CPU replicas"]
+ BS["Buffer Service - 1 CPU replica"]
+ end
+
+ subgraph MonarchInfra["Monarch Infrastructure"]
+ subgraph GPUMesh["GPU ProcMesh (6 processes)"]
+ G0["Process 0 - GPU 0"]
+ G1["Process 1 - GPU 1"]
+ G2["Process 2 - GPU 2"]
+ G3["Process 3 - GPU 3"]
+ G4["Process 4 - GPU 4"]
+ G5["Process 5 - GPU 5"]
+ end
+
+ subgraph CPUMesh["CPU ProcMesh (5 processes)"]
+ C0["Process 0 - CPU"]
+ C1["Process 1 - CPU"]
+ C2["Process 2 - CPU"]
+ C3["Process 3 - CPU"]
+ C4["Process 4 - CPU"]
+ end
+ end
+
+ PS --> G0
+ PS --> G1
+ PS --> G2
+ PS --> G3
+ TS --> G4
+ TS --> G5
+ RS --> C0
+ RS --> C1
+ RS --> C2
+ RS --> C3
+ BS --> C4
+ end
+
+ style PS fill:#4CAF50
+ style TS fill:#E91E63
+ style RS fill:#FF9800
+ style BS fill:#9C27B0
+ style GPUMesh fill:#FFEBEE
+ style CPUMesh fill:#E3F2FD
+```
+
+## Key Insights: Why This Architecture Matters
+
+1. **Process Isolation**: Each actor runs in its own process - failures don't cascade
+2. **Location Transparency**: Actors can be local or remote with identical APIs
+3. **Structured Distribution**: ProcMesh maps directly to hardware topology
+4. **Message Passing**: No shared memory means no race conditions or locks
+5. **Service Abstraction**: TorchForge hides Monarch complexity while preserving power
+
+Understanding this hierarchy helps you:
+- **Debug performance issues**: Is the bottleneck at service, actor, or hardware level?
+- **Optimize resource usage**: How many replicas per service? GPU vs CPU processes?
+- **Handle failures gracefully**: Which layer failed and how to recover?
+- **Scale effectively**: Where to add resources for maximum impact?
+
+# Conclusion
+
+## What You've Learned
+
+1. **RL Fundamentals**: How RL concepts map to TorchForge services with REAL, working examples
+2. **Service Abstraction**: How to use TorchForge services effectively with verified communication patterns
+3. **Monarch Foundation**: How TorchForge services connect to distributed actors and hardware
+
+## Key Takeaways
+
+- **Services hide complexity**: Your RL code looks like simple async functions, but runs on distributed clusters
+- **Communication patterns matter**: `.route()`, `.fanout()`, sessions, and `.call_one()` each serve specific purposes
+- **Architecture understanding helps**: Knowing the Service → Actor → Process → Hardware hierarchy helps you debug, optimize, and scale
+- **Always verify APIs**: This guide is verified, but cross-check with source code for latest changes
+- **Real API patterns**: Use `.options().as_service()` not `spawn_service()`, use `.route()` not `.choose()`, etc.
diff --git a/docs/source/tutorials.md b/docs/source/tutorials.md
index 5cfd3dbaf..662cb66e5 100644
--- a/docs/source/tutorials.md
+++ b/docs/source/tutorials.md
@@ -2,3 +2,10 @@
This section provides step-by-step guides to help you master TorchForge's capabilities,
from basic model fine-tuning to advanced distributed training scenarios.
+
+```{toctree}
+:maxdepth: 1
+
+zero-to-forge-intro
+metric_logging
+```
diff --git a/docs/source/usage.md b/docs/source/usage.md
deleted file mode 100644
index 2a61c577a..000000000
--- a/docs/source/usage.md
+++ /dev/null
@@ -1,4 +0,0 @@
-# Usage
-
-This section covers practical usage patterns,
-configuration management, and common scenarios for working with TorchForge effectively.
diff --git a/docs/source/zero-to-forge-intro.md b/docs/source/zero-to-forge-intro.md
new file mode 100644
index 000000000..7e815c83b
--- /dev/null
+++ b/docs/source/zero-to-forge-intro.md
@@ -0,0 +1,29 @@
+# Zero to TorchForge: From RL Theory to Production-Scale Implementation
+
+A comprehensive guide for ML Engineers building distributed RL systems for language models.
+
+Some of the examples mentioned below will be conceptual in nature for understanding.
+Please refer to [API Docs](./api) for more details.
+
+Welcome to the Tutorials section! This section is inspired by the A-Z
+PyTorch tutorial, shoutout to our PyTorch friends that remember!
+
+## Tutorial Structure
+
+This section currently is structured in 3 detailed parts:
+
+1. [Part 1: RL Fundamentals - Using TorchForge Terminology](tutorials/zero-to-forge/1_RL_and_Forge_Fundamentals): This gives a quick refresher of Reinforcement Learning and teaches you TorchForge Fundamentals
+2. [Part 2: Peeling Back the Abstraction - What Are Services?](tutorials/zero-to-forge/2_Forge_Internals): Goes a layer deeper and explains the internals of TorchForge
+3. [Part 3: The TorchForge-Monarch Connection](tutorials/zero-to-forge/3_Monarch_101): It's a 101 to Monarch and how TorchForge Talks to Monarch
+
+Each part builds upon the next and the entire section can be consumed in roughly an hour - Grab a Chai and Enjoy!
+
+If you're eager, please checkout our SFT Tutorial too (Coming soon!)!
+
+```{toctree}
+:maxdepth: 1
+:hidden:
+tutorials/zero-to-forge/1_RL_and_Forge_Fundamentals
+tutorials/zero-to-forge/2_Forge_Internals
+tutorials/zero-to-forge/3_Monarch_101
+```
diff --git a/launcher/job.sbatch b/launcher/job.sbatch
deleted file mode 100644
index 58695c5ee..000000000
--- a/launcher/job.sbatch
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-#!/bin/bash
-#SBATCH --job-name=forge
-#SBATCH --output=slogs/forge.out
-#SBATCH --error=slogs/forge.err
-#SBATCH --partition=h100-low # or h100-high / h100-prod / all
-#SBATCH --nodes=1 # 1 node
-#SBATCH --ntasks=1 # 1 task (process)
-#SBATCH --gres=gpu:8 # request 8 GPUs
-#SBATCH --time=01:00:00 # walltime hh:mm:ss
-
-unset SLURM_MEM_PER_CPU SLURM_MEM_PER_GPU SLURM_MEM_PER_NODE
-echo "Running on $SLURM_JOB_NODELIST"
-python -m apps.grpo.main --config=apps/grpo/multihost.yaml
diff --git a/pyproject.toml b/pyproject.toml
index 886ed672c..48bd8c238 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,11 +11,13 @@ authors = [
keywords = ["pytorch", "training", "llm"]
dependencies = [
# PyTorch
+ "torch==2.9.0",
"torchdata>=0.8.0",
- "torchtitan",
+ "torchtitan==0.2.0",
+ "torchmonarch==0.1.2",
+ "torchstore==0.1.2",
# vLLM
- # TODO: pin specific vllm version
- #"vllm==0.10.0",
+ "vllm",
# Hugging Face integrations
"datasets>=2.21.0",
"tokenizers",
@@ -23,13 +25,15 @@ dependencies = [
"omegaconf",
"wandb",
"hf_transfer",
+ "six",
+ "setuptools<80",
]
dynamic = ["version"]
[project.urls]
-GitHub = "https://github.com/pytorch-labs/forge"
-Documentation = "https://github.com/pytorch-labs/forge/tree/main/docs"
-Issues = "https://github.com/pytorch-labs/forge/issues"
+GitHub = "https://github.com/meta-pytorch/torchforge"
+Documentation = "https://meta-pytorch.org/torchforge"
+Issues = "https://github.com/meta-pytorch/torchforge/issues"
[project.optional-dependencies]
dev = [
@@ -42,17 +46,21 @@ dev = [
"tomli>=1.1.0",
"anyio",
"pytest-asyncio",
+ "multiprocess",
+ "langid",
]
-oss = [
- "torch",
- "torchmonarch-nightly==2025.8.1",
- "torchstore",
+docs = [
+ "sphinx==7.2.6",
+ "pytorch-sphinx-theme2==0.1.0",
+ "docutils>=0.18.1,<0.21",
+ "sphinx-design==0.6.1",
+ "sphinxcontrib-mermaid==1.0.0",
+ "sphinx-gallery==0.19.0",
+ "myst-parser",
+ "sphinx-sitemap==2.7.1",
+ "sphinx-autodoc-typehints==1.25.3",
]
-[project.scripts]
-forge = "forge.cli.forge:main"
-
-
# ---- Explicit project build information ---- #
[build-system]
requires = ["setuptools>=61.0"]
@@ -72,23 +80,18 @@ members = [
]
# pytorch
-# TODO: get auto backend to work
[[tool.uv.index]]
-name = "pytorch-nightly-cu129"
-url = "https://download.pytorch.org/whl/nightly/cu129"
-#explicit = true
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
# vllm
-# [[tool.uv.index]]
-# name = "vllm-nightly"
-# url = "https://wheels.vllm.ai/nightly"
-# explicit = true
+[[tool.uv.index]]
+name = "vllm-forge"
+url = "https://download.pytorch.org/whl/preview/forge"
[tool.uv.sources]
-torchtitan = { index = "pytorch-nightly-cu129" }
-torch = { index = "pytorch-nightly-cu129" }
-torchstore = { git = "ssh://git@github.com/meta-pytorch/torchstore.git" }
-#vllm = { index = "vllm-nightly" }
+torch = { index = "pytorch-cu128" }
+vllm = { index = "vllm-forge" }
[tool.uv]
# TODO: revert to stricter default uv strategy
@@ -98,4 +101,9 @@ prerelease = "allow"
environments = [
"sys_platform == 'linux'",
]
-# override-dependencies = ["torch>2.7.1", "torchaudio>=2.7.1", "torchvision>=0.22.0"]
+
+[tool.black]
+target-version = ["py310"] # match the minium supported python version
+
+[tool.usort]
+first_party_detection = false
diff --git a/scripts/build_wheels.sh b/scripts/build_wheels.sh
deleted file mode 100755
index 4c1f900dd..000000000
--- a/scripts/build_wheels.sh
+++ /dev/null
@@ -1,353 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-#!/bin/bash
-set -euo pipefail
-
-# Colors for output
-RED='\033[0;31m'
-GREEN='\033[0;32m'
-YELLOW='\033[1;33m'
-BLUE='\033[0;34m'
-NC='\033[0m'
-
-# Configuration
-PYTORCH_VERSION="2.9.0.dev20250905"
-VLLM_BRANCH="v0.10.0"
-MONARCH_COMMIT="16e3de376b22b5c44ee3853af5576e4998ea74bf"
-TORCHTITAN_COMMIT="0cfbd0b3c2d827af629a107a77a9e47229c31663"
-TORCHSTORE_COMMIT="eed96eb55ce87d4a9880597dd7dfd0d291e9ac81"
-BUILD_DIR="$HOME/forge-build"
-WHEEL_DIR="$(pwd)/assets/wheels"
-
-# Logging functions
-log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
-log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
-log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
-log_step() { echo -e "${BLUE}[$1/$2]${NC} $3"; }
-
-# Track current step for recovery
-CURRENT_STEP=0
-TOTAL_STEPS=8
-
-# Function to handle step failures
-handle_failure() {
- local step_name="$1"
- local step_num="$2"
- local exit_code="$3"
- local retry_cmd="$4"
-
- log_error "Step $step_num failed: $step_name"
- log_error "Exit code: $exit_code"
- log_error "Working directory: $(pwd)"
- echo ""
- log_info "To retry this step manually:"
- echo " $retry_cmd"
- echo ""
- log_info "Or to resume from step $step_num:"
- echo " $0 --resume-from=$step_num"
- echo ""
- exit $exit_code
-}
-
-# Validation functions
-check_conda_env() {
- if [ -z "${CONDA_DEFAULT_ENV:-}" ]; then
- log_error "Not running in a conda environment"
- log_info "Please create and activate your conda environment first:"
- log_info " conda create -n forge python=3.10 -y"
- log_info " conda activate forge"
- exit 1
- fi
- log_info "Running in conda environment: $CONDA_DEFAULT_ENV"
-}
-
-check_command() {
- if ! command -v "$1" &> /dev/null; then
- log_error "Required command '$1' not found"
- exit 1
- fi
-}
-
-check_sudo() {
- if ! sudo -n true 2>/dev/null; then
- log_error "This script requires passwordless sudo access"
- log_info "Run 'sudo -v' first, or configure passwordless sudo"
- exit 1
- fi
-}
-
-check_disk_space() {
- local required_gb=10
- local available_gb=$(df ~/ --output=avail -BG | tail -1 | sed 's/G//')
- if [ "$available_gb" -lt "$required_gb" ]; then
- log_error "Insufficient disk space. Need ${required_gb}GB, have ${available_gb}GB"
- exit 1
- fi
-}
-
-# Main validation
-validate_environment() {
- log_info "Validating environment..."
-
- check_conda_env
- check_command git
- check_command curl
- check_command python
- check_command pip
- check_command conda
- check_sudo
- check_disk_space
-
- # Check if CUDA toolkit will be available
- if ! ldconfig -p | grep -q cuda; then
- log_warn "CUDA libraries not found in ldconfig. Will attempt to install CUDA toolkit."
- fi
-
- log_info "Environment validation passed"
-}
-
-# Setup build directory and wheels directory
-setup_build_dir() {
- log_info "Setting up build directory: $BUILD_DIR"
- mkdir -p "$BUILD_DIR"
- log_info "Setting up wheels directory: $WHEEL_DIR"
- mkdir -p "$WHEEL_DIR"
- log_info "Build and wheels directories created"
-}
-
-# Setup CUDA environment variables
-setup_cuda_env() {
- log_info "Setting up CUDA environment..."
-
- export CUDA_VERSION=12.9
- export NVCC=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
- export CUDA_NVCC_EXECUTABLE=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
- export CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
- export PATH="${CUDA_HOME}/bin:$PATH"
- export CUDA_INCLUDE_DIRS=$CUDA_HOME/include
- export CUDA_CUDART_LIBRARY=$CUDA_HOME/lib64/libcudart.so
- export LD_LIBRARY_PATH=/usr/local/cuda-12.9/compat:${LD_LIBRARY_PATH:-}
- export LIBRARY_PATH=$CUDA_HOME/lib64:${LIBRARY_PATH:-}
-
- # Save to file for persistence
- cat > ~/.forge_cuda_env << 'EOF'
-export CUDA_VERSION=12.9
-export NVCC=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
-export CUDA_NVCC_EXECUTABLE=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
-export CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
-export PATH="${CUDA_HOME}/bin:$PATH"
-export CUDA_INCLUDE_DIRS=$CUDA_HOME/include
-export CUDA_CUDART_LIBRARY=$CUDA_HOME/lib64/libcudart.so
-export LD_LIBRARY_PATH=/usr/local/cuda-12.9/compat:${LD_LIBRARY_PATH:-}
-export LIBRARY_PATH=${CUDA_HOME}/lib64:${LIBRARY_PATH:-}
-EOF
-
- log_info "CUDA environment configured"
-}
-
-# Parse command line arguments
-RESUME_FROM=1
-while [[ $# -gt 0 ]]; do
- case $1 in
- --resume-from=*)
- RESUME_FROM="${1#*=}"
- shift
- ;;
- *)
- log_error "Unknown argument: $1"
- exit 1
- ;;
- esac
-done
-
-# Step execution wrapper
-run_step() {
- local step_num="$1"
- local step_name="$2"
- local step_function="$3"
-
- if [ "$step_num" -lt "$RESUME_FROM" ]; then
- log_info "Skipping step $step_num: $step_name (resuming from step $RESUME_FROM)"
- return 0
- fi
-
- CURRENT_STEP=$step_num
- log_step "$step_num" "$TOTAL_STEPS" "$step_name"
-
- if ! $step_function; then
- handle_failure "$step_name" "$step_num" "$?" "Run step $step_num manually"
- fi
-}
-
-# Step 1: Install PyTorch nightly
-step1_pytorch() {
- pip3 install --pre torch==$PYTORCH_VERSION --index-url https://download.pytorch.org/whl/nightly/cu129
-}
-
-# Step 2: Install CUDA system packages
-step2_cuda_packages() {
- sudo dnf install -y cuda-toolkit-12-9 cuda-compat-12-9
- setup_cuda_env
-}
-
-# Step 3: Build vLLM wheel
-step3_vllm() {
- cd "$BUILD_DIR"
- if [ -d "vllm" ]; then
- log_warn "vLLM directory exists, removing..."
- rm -rf vllm
- fi
-
- git clone https://github.com/vllm-project/vllm.git --branch $VLLM_BRANCH
- cd "$BUILD_DIR/vllm"
-
- python use_existing_torch.py
- pip install -r requirements/build.txt
- pip wheel --no-build-isolation --no-deps . -w "$WHEEL_DIR"
-}
-
-# Step 4: Setup Rust toolchain
-step4_rust_setup() {
- # Install Rust if not present
- if ! command -v rustup &> /dev/null; then
- curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
- source ~/.cargo/env
- fi
-
- rustup toolchain install nightly
- rustup default nightly
-
- # Install additional system packages
- conda install -y libunwind
- sudo dnf install -y clang-devel libnccl-devel
- sudo dnf install -y libibverbs rdma-core libmlx5 libibverbs-devel rdma-core-devel
-}
-
-# Step 5: Build Monarch wheel
-step5_monarch() {
- cd "$BUILD_DIR"
- if [ -d "monarch" ]; then
- log_warn "Monarch directory exists, removing..."
- rm -rf monarch
- fi
-
- git clone https://github.com/meta-pytorch/monarch.git
- cd "$BUILD_DIR/monarch"
- git checkout $MONARCH_COMMIT
-
- pip install -r build-requirements.txt
- pip wheel --no-build-isolation --no-deps . -w "$WHEEL_DIR"
-}
-
-# Step 6: Build torchtitan wheel
-step6_torchtitan() {
- cd "$BUILD_DIR"
- if [ -d "torchtitan" ]; then
- log_warn "torchtitan directory exists, removing..."
- rm -rf torchtitan
- fi
-
- git clone https://github.com/pytorch/torchtitan.git
- cd "$BUILD_DIR/torchtitan"
- git checkout $TORCHTITAN_COMMIT
-
- pip wheel --no-deps . -w "$WHEEL_DIR"
-}
-
-# Step 7: Build torchstore wheel
-step7_torchstore() {
- cd "$BUILD_DIR"
- if [ -d "torchstore" ]; then
- log_warn "torchstore directory exists, removing..."
- rm -rf torchstore
- fi
-
- git clone https://github.com/meta-pytorch/torchstore.git
- cd "$BUILD_DIR/torchstore"
- git checkout $TORCHSTORE_COMMIT
-
- pip wheel --no-deps . -w "$WHEEL_DIR"
-}
-
-# Verification
-verify_installation() {
- log_info "Verifying wheel builds..."
-
- python -c "import torch; print(f'PyTorch {torch.__version__} (CUDA: {torch.cuda.is_available()})')"
-
- # Check that wheels were created
- wheel_count=$(ls -1 "$WHEEL_DIR"/*.whl 2>/dev/null | wc -l)
- if [ "$wheel_count" -gt 0 ]; then
- log_info "Built $wheel_count wheels:"
- ls -1 "$WHEEL_DIR"/*.whl | sed 's/.*\// /'
- else
- log_error "No wheels found in $WHEEL_DIR"
- return 1
- fi
-
- log_info "Wheel building verification complete!"
-}
-
-# Main execution
-main() {
- echo "Forge Wheel Builder"
- echo "==================="
- echo ""
-
- if [ "$RESUME_FROM" -gt 1 ]; then
- log_info "Resuming from step $RESUME_FROM..."
- # Source CUDA env if resuming
- if [ -f ~/.forge_cuda_env ]; then
- source ~/.forge_cuda_env
- fi
- # Source Rust env if resuming
- if [ -f ~/.cargo/env ]; then
- source ~/.cargo/env
- fi
- else
- validate_environment
- setup_build_dir
- fi
-
- run_step 1 "Installing PyTorch nightly" step1_pytorch
- run_step 2 "Installing CUDA packages and setting environment" step2_cuda_packages
- run_step 3 "Building vLLM wheel" step3_vllm
- run_step 4 "Setting up Rust toolchain and additional packages" step4_rust_setup
- run_step 5 "Building Monarch wheel" step5_monarch
- run_step 6 "Building torchtitan wheel" step6_torchtitan
- run_step 7 "Building torchstore wheel" step7_torchstore
-
- verify_installation
-
- echo ""
- log_info "Wheel building completed successfully!"
- log_info ""
- log_info "Built wheels are in: $WHEEL_DIR"
- log_info ""
- log_info "Users can now install with:"
- log_info " conda create -n forge python=3.10 -y"
- log_info " conda activate forge"
- log_info " pip install torch==$PYTORCH_VERSION --index-url https://download.pytorch.org/whl/nightly/cu129"
- log_info " pip install $WHEEL_DIR/*.whl"
- log_info " source ~/.forge_cuda_env"
- log_info ""
- log_info "Build artifacts are in: $BUILD_DIR"
- log_info "You can remove them with: rm -rf $BUILD_DIR"
-}
-
-# Set trap for cleanup on failure
-cleanup() {
- if [ $? -ne 0 ] && [ $CURRENT_STEP -gt 0 ]; then
- echo ""
- log_error "Setup failed at step $CURRENT_STEP"
- log_info "You can resume with: $0 --resume-from=$CURRENT_STEP"
- fi
-}
-trap cleanup EXIT
-
-# Run main function
-main "$@"
diff --git a/scripts/generate_vllm_reqs.sh b/scripts/generate_vllm_reqs.sh
new file mode 100755
index 000000000..6da96c200
--- /dev/null
+++ b/scripts/generate_vllm_reqs.sh
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/bin/bash
+set -euo pipefail
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m'
+
+# Source version configuration
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+VERSIONS_FILE="$SCRIPT_DIR/../assets/versions.sh"
+
+if [ ! -f "$VERSIONS_FILE" ]; then
+ echo -e "${RED}[ERROR]${NC} Versions file not found: $VERSIONS_FILE"
+ exit 1
+fi
+
+source "$VERSIONS_FILE"
+
+# Configuration
+BUILD_DIR="$HOME/forge-build"
+WHEEL_DIR="$(pwd)/assets/wheels"
+
+# Logging functions
+log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
+log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
+log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
+
+
+# Validation functions
+check_conda_env() {
+ if [ -z "${CONDA_DEFAULT_ENV:-}" ]; then
+ log_error "Not running in a conda environment"
+ log_info "Please create and activate your conda environment first:"
+ log_info " conda create -n forge python=3.10 -y"
+ log_info " conda activate forge"
+ exit 1
+ fi
+ log_info "Running in conda environment: $CONDA_DEFAULT_ENV"
+}
+
+check_command() {
+ if ! command -v "$1" &> /dev/null; then
+ log_error "Required command '$1' not found"
+ exit 1
+ fi
+}
+
+check_sudo() {
+ if ! sudo -n true 2>/dev/null; then
+ log_error "This script requires passwordless sudo access"
+ log_info "Run 'sudo -v' first, or configure passwordless sudo"
+ exit 1
+ fi
+}
+
+check_disk_space() {
+ local required_gb=10
+ local available_gb=$(df ~/ --output=avail -BG | tail -1 | sed 's/G//')
+ if [ "$available_gb" -lt "$required_gb" ]; then
+ log_error "Insufficient disk space. Need ${required_gb}GB, have ${available_gb}GB"
+ exit 1
+ fi
+}
+
+# Main validation
+validate_environment() {
+ log_info "Validating environment..."
+
+ check_conda_env
+ check_command git
+ check_command curl
+ check_command python
+ check_command pip
+ check_command conda
+ check_sudo
+ check_disk_space
+
+ # Check if CUDA toolkit will be available
+ if ! ldconfig -p | grep -q cuda; then
+ log_warn "CUDA libraries not found in ldconfig. Will attempt to install CUDA toolkit."
+ fi
+
+ log_info "Environment validation passed"
+}
+
+# Setup build directory and wheels directory
+setup_build_dir() {
+ log_info "Setting up build directory: $BUILD_DIR"
+ mkdir -p "$BUILD_DIR"
+ log_info "Setting up wheels directory: $WHEEL_DIR"
+ mkdir -p "$WHEEL_DIR"
+ log_info "Build and wheels directories created"
+}
+
+# Setup CUDA environment variables
+setup_cuda_env() {
+ log_info "Setting up CUDA environment..."
+
+ export CUDA_VERSION=12.8
+ export NVCC=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
+ export CUDA_NVCC_EXECUTABLE=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
+ export CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
+ export PATH="${CUDA_HOME}/bin:$PATH"
+ export CUDA_INCLUDE_DIRS=$CUDA_HOME/include
+ export CUDA_CUDART_LIBRARY=$CUDA_HOME/lib64/libcudart.so
+ export LD_LIBRARY_PATH=/usr/local/cuda-12.8/compat:${LD_LIBRARY_PATH:-}
+ export LIBRARY_PATH=$CUDA_HOME/lib64:${LIBRARY_PATH:-}
+
+ # Save to file for persistence
+ cat > ~/.forge_cuda_env << 'EOF'
+export CUDA_VERSION=12.8
+export NVCC=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
+export CUDA_NVCC_EXECUTABLE=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
+export CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
+export PATH="${CUDA_HOME}/bin:$PATH"
+export CUDA_INCLUDE_DIRS=$CUDA_HOME/include
+export CUDA_CUDART_LIBRARY=$CUDA_HOME/lib64/libcudart.so
+export LD_LIBRARY_PATH=/usr/local/cuda-12.8/compat:${LD_LIBRARY_PATH:-}
+export LIBRARY_PATH=${CUDA_HOME}/lib64:${LIBRARY_PATH:-}
+EOF
+
+ log_info "CUDA environment configured"
+}
+
+# Step 1: Install PyTorch stable
+step1_pytorch() {
+ pip3 install --pre torch==$PYTORCH_VERSION --index-url https://download.pytorch.org/whl/cu128
+}
+
+# Step 2: Install CUDA system packages
+step2_cuda_packages() {
+ sudo dnf install -y cuda-toolkit-12-8 cuda-compat-12-8
+ setup_cuda_env
+}
+
+# Step 3: Build vLLM wheel
+step3_vllm() {
+ log_info "Building vLLM from branch: $VLLM_VERSION (from $VERSIONS_FILE)"
+ cd "$BUILD_DIR"
+ if [ -d "vllm" ]; then
+ log_warn "vLLM directory exists, removing..."
+ rm -rf vllm
+ fi
+
+ git clone https://github.com/vllm-project/vllm.git --branch $VLLM_VERSION
+ cd "$BUILD_DIR/vllm"
+
+ python use_existing_torch.py
+ pip install -r requirements/build.txt
+ pip install --no-build-isolation -e .
+}
+
+# Main execution
+main() {
+ echo "Forge Wheel Builder"
+ echo "==================="
+ echo ""
+
+ validate_environment
+ setup_build_dir
+
+ # Install PyTorch, CUDA packages, and vLLM
+ step1_pytorch
+ step2_cuda_packages
+ step3_vllm
+
+ # Output requirements to .github/packaging/vllm_reqs_12_8.txt
+ REQS_FILE="$SCRIPT_DIR/../.github/packaging/vllm_reqs_12_8.txt"
+ pip freeze | grep -v "vllm*" > $REQS_FILE
+ sed -i '1i# This file was generated by running ./scripts/generate_vllm_reqs.sh' $REQS_FILE
+}
+
+
+# Run main function
+main "$@"
diff --git a/scripts/install.sh b/scripts/install.sh
index eb4776cfd..ba15699cf 100755
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -18,8 +18,22 @@ log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
log_warning() { echo -e "${YELLOW}[WARNING]${NC} $1";}
# Configuration
-PYTORCH_VERSION="2.9.0.dev20250905"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+VERSIONS_FILE="$SCRIPT_DIR/../assets/versions.sh"
+
+if [ ! -f "$VERSIONS_FILE" ]; then
+ log_error "Versions file not found: $VERSIONS_FILE"
+ exit 1
+fi
+
+source "$VERSIONS_FILE"
+
+# Validate required variables are set
+if [ -z "${PYTORCH_VERSION:-}" ]; then
+ log_error "PYTORCH_VERSION not set in $VERSIONS_FILE"
+ exit 1
+fi
+
WHEEL_DIR="$SCRIPT_DIR/../assets/wheels"
RELEASE_TAG="v0.0.0-93025"
GITHUB_REPO="meta-pytorch/forge"
@@ -124,85 +138,6 @@ install_system_packages() {
fi
}
-# Check to see if gh is installed, if not, it will be installed via conda-forge channel
-check_gh_install() {
- if ! command -v gh &> /dev/null; then
- log_warning "GitHub CLI (gh) not found. Installing via Conda..."
- conda install gh --channel conda-forge -y
- log_info "GitHub CLI (gh) installed successfully."
- log_info "Please run 'gh auth login' to authenticate with GitHub."
- else
- log_info "GitHub CLI (gh) already installed."
- fi
-}
-
-# Check wheels exist
-check_wheels() {
- if [ ! -d "$WHEEL_DIR" ]; then
- log_error "Wheels directory not found: $WHEEL_DIR"
- exit 1
- fi
-
- local wheel_count=$(ls -1 "$WHEEL_DIR"/*.whl 2>/dev/null | wc -l)
- log_info "Found $wheel_count local wheels"
-}
-
-# Download vLLM wheel from GitHub releases
-download_vllm_wheel() {
- log_info "Downloading vLLM wheel from GitHub releases..."
-
- # Check if gh is installed
- if ! command -v gh &> /dev/null; then
- log_error "GitHub CLI (gh) is required to download vLLM wheel"
- log_info "Install it with: sudo dnf install gh"
- log_info "Then run: gh auth login"
- exit 1
- fi
-
- # Get the vLLM wheel filename from the release
- local vllm_wheel_name
- vllm_wheel_name=$(gh release view "$RELEASE_TAG" --repo "$GITHUB_REPO" --json assets --jq '.assets[] | select(.name | contains("vllm")) | .name' | head -1)
-
- if [ -z "$vllm_wheel_name" ]; then
- log_error "Could not find vLLM wheel in release $RELEASE_TAG"
- log_info "Make sure the vLLM wheel has been uploaded to the GitHub release"
- exit 1
- fi
- for f in assets/wheels/vllm-*; do
- [ -e "$f" ] || continue # skip if glob didn't match
- if [ "$(basename "$f")" != "$vllm_wheel_name" ]; then
- log_info "Removing stale vLLM wheel: $(basename "$f")"
- rm -f "$f"
- fi
- done
-
- local local_path="$WHEEL_DIR/$vllm_wheel_name"
-
- if [ -f "$local_path" ]; then
- log_info "vLLM wheel already downloaded: $vllm_wheel_name"
- return 0
- fi
-
- log_info "Downloading: $vllm_wheel_name"
-
- # Save current directory and change to wheel directory
- local original_dir=$(pwd)
- cd "$WHEEL_DIR"
- gh release download "$RELEASE_TAG" --repo "$GITHUB_REPO" --pattern "*vllm*"
- local download_result=$?
-
- # Always return to original directory
- cd "$original_dir"
-
- if [ $download_result -eq 0 ]; then
- log_info "Successfully downloaded vLLM wheel"
- else
- log_error "Failed to download vLLM wheel"
- exit 1
- fi
-}
-
-
# Parse command line arguments
parse_args() {
USE_SUDO=false
@@ -241,7 +176,6 @@ main() {
echo "======================"
echo ""
echo "Note: Run this from the root of the forge repository"
- echo "This script requires GitHub CLI (gh) to download large wheels"
if [ "$USE_SUDO" = "true" ]; then
echo "System packages will be installed via system package manager (requires sudo)"
check_sudo
@@ -250,24 +184,29 @@ main() {
fi
echo ""
- check_conda_env
- check_wheels
-
# Install openssl as we overwrite the default version when we update LD_LIBRARY_PATH
conda install -y openssl
install_system_packages "$USE_SUDO"
- check_gh_install
- download_vllm_wheel
- log_info "Installing PyTorch nightly..."
- pip install torch==$PYTORCH_VERSION --index-url https://download.pytorch.org/whl/nightly/cu129
+ log_info "Installing PyTorch ..."
+ pip install torch==$PYTORCH_VERSION --index-url https://download.pytorch.org/whl/cu128
+
+ # Install vLLM and its requirements
+ pip install -r .github/packaging/vllm_reqs_12_8.txt
+ pip install six
+ pip install "setuptools<80"
+ python -m pip install vllm --no-cache-dir --index-url https://download.pytorch.org/whl/preview/forge
+
+ # Install monarch
+ pip install torchmonarch==$MONARCH_VERSION
- log_info "Installing all wheels (local + downloaded)..."
- pip install "$WHEEL_DIR"/*.whl
+ # Install torchtitan and torchstore
+ pip install torchtitan==$TORCHTITAN_VERSION
+ pip install torchstore==$TORCHSTORE_VERSION
log_info "Installing Forge from source..."
- pip install -e .
+ pip install -e ".[dev]"
# Set up environment
log_info "Setting up environment..."
@@ -287,7 +226,7 @@ main() {
local cuda_activation_script="${conda_env_dir}/etc/conda/activate.d/cuda_env.sh"
cat > "$cuda_activation_script" << 'EOF'
# CUDA environment for Forge
-export CUDA_VERSION=12.9
+export CUDA_VERSION=12.8
export NVCC=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
export CUDA_NVCC_EXECUTABLE=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
export CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py
index 54e450cd7..772e2e216 100644
--- a/src/forge/actors/__init__.py
+++ b/src/forge/actors/__init__.py
@@ -4,19 +4,34 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]
+import warnings
+
+__all__ = [
+ "Generator",
+ "TitanTrainer",
+ "RLTrainer", # Deprecated, use TitanTrainer
+ "ReplayBuffer",
+ "ReferenceModel",
+ "SandboxedPythonCoder",
+]
def __getattr__(name):
- if name == "Policy":
- from .policy import Policy
+ if name == "Generator":
+ from .generator import Generator
- return Policy
- elif name == "PolicyRouter":
- from .policy import PolicyRouter
+ return Generator
+ elif name == "TitanTrainer":
+ from .trainer import TitanTrainer
- return PolicyRouter
+ return TitanTrainer
elif name == "RLTrainer":
+ warnings.warn(
+ "RLTrainer is deprecated and will be removed in a future version. "
+ "Please use TitanTrainer instead.",
+ FutureWarning,
+ stacklevel=2,
+ )
from .trainer import RLTrainer
return RLTrainer
@@ -28,5 +43,9 @@ def __getattr__(name):
from .reference_model import ReferenceModel
return ReferenceModel
+ elif name == "SandboxedPythonCoder":
+ from .coder import SandboxedPythonCoder
+
+ return SandboxedPythonCoder
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/actors/_torchstore_utils.py
index bc0d55c3b..2d14f7f30 100644
--- a/src/forge/actors/_torchstore_utils.py
+++ b/src/forge/actors/_torchstore_utils.py
@@ -10,6 +10,7 @@
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta
+from torchstore.transport.buffers import rdma_available
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -69,3 +70,8 @@ def extract_param_name(key: str) -> str:
def get_dcp_whole_state_dict_key(policy_version: int) -> str:
return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"
+
+
+def rdma_enabled() -> bool:
+ """Return if TorchStore thinks we're using RDMA"""
+ return rdma_available()
diff --git a/src/forge/actors/coder.py b/src/forge/actors/coder.py
new file mode 100644
index 000000000..819c488e1
--- /dev/null
+++ b/src/forge/actors/coder.py
@@ -0,0 +1,197 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import subprocess
+import tempfile
+from pathlib import Path
+
+from forge.controller import ForgeActor
+
+from monarch.actor import endpoint
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class _SandboxedPythonCoder:
+ """A sandboxed code execution environment using enroot containers.
+
+ This is a proof of concept of using enroot to provided a sandboxed
+ environment for executing Python code using NVIDIA's enroot technology.
+
+ It automatically manages the entire container lifecycle including image
+ import, container creation, and cleanup.
+
+ The actor follows a three-stage workflow:
+ 1. Image Management: Automatically imports Docker images to enroot .sqsh format
+ 2. Container Lifecycle: Creates fresh container instances for isolated execution
+ 3. Code Execution: Safely runs Python code with proper error handling and output capture
+
+ Dependencies:
+ - enroot: NVIDIA's container runtime (must be installed on host)
+ - Docker images: Accessible via docker:// URLs or local paths
+
+ Args:
+ docker_image: Docker image URL to import (e.g., "docker://python:3.10").
+ Can be any Docker Hub image or custom registry URL.
+ sqsh_image_path: Local filesystem path where the enroot .sqsh image will be stored.
+ If the file doesn't exist, it will be created via enroot import.
+ container_name: Unique name for the enroot container instance. Used for
+ container lifecycle management (create/remove operations).
+
+ """
+
+ def __init__(
+ self,
+ docker_image: str = "docker://python:3.10",
+ sqsh_image_path: str = "python-image.sqsh",
+ container_name: str = "sandbox",
+ ):
+ self.docker_image = docker_image
+ self.sqsh_image_path = sqsh_image_path
+ self.container_name = container_name
+ self._initialized = False
+
+ async def setup(self):
+ """Setup the sandboxed environment."""
+ logging.debug("Setting up sandboxed actor")
+ await self._maybe_create_image()
+ self._recreate()
+
+ async def recreate(self):
+ """Recreates the container instance from the base image."""
+ self._recreate()
+
+ async def _maybe_create_image(self):
+ """Ensure the enroot image exists, import it if necessary."""
+ if not os.path.exists(self.sqsh_image_path):
+ logging.debug(
+ f"Image {self.sqsh_image_path} not found, importing from {self.docker_image}"
+ )
+ result = subprocess.run(
+ ["enroot", "import", "-o", self.sqsh_image_path, self.docker_image],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ )
+ if result.returncode != 0:
+ raise RuntimeError(f"Failed to import image: {result.stderr}")
+ logging.debug(
+ f"Successfully imported {self.docker_image} to {self.sqsh_image_path}"
+ )
+ else:
+ logging.info(f"Using existing image: {self.sqsh_image_path}")
+
+ def _recreate(self):
+ """(Re)create a clean container instance from the base image."""
+ # Remove any old container
+ logging.debug(f"Removing container {self.container_name}")
+ subprocess.run(
+ ["enroot", "remove", "-f", self.container_name],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ # Create new container from image
+ result = subprocess.run(
+ ["enroot", "create", "--name", self.container_name, self.sqsh_image_path],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ )
+ logging.debug(f"Container creation result: {result}")
+ if result.returncode != 0:
+ raise RuntimeError(f"Failed to recreate container: {result.stderr}")
+ self._initialized = True
+ logging.debug("Successfully initialized container")
+
+ async def execute(self, code: str) -> tuple[str, str]:
+ """Executes Python code inside the container and returns the output.
+
+ Args:
+ code: Python source code string to execute.
+
+ Returns:
+ The captured stdout and stderr from the execution, as a
+ (stdout, stderr) tuple of strings.
+ """
+ logging.debug(f"Executing {code}")
+ if not self._initialized:
+ raise RuntimeError("Container not initialized. Call recreate() first.")
+
+ # Write code to a temporary file that we can mount
+ with tempfile.TemporaryDirectory() as tmpdir:
+ code_path = Path(tmpdir) / "script.py"
+ code_path.write_text(code)
+
+ # Run the code inside the container, mounting tmpdir
+ cmd = [
+ "enroot",
+ "start",
+ "--mount",
+ f"{tmpdir}:/work",
+ self.container_name,
+ "python3",
+ "/work/script.py",
+ ]
+ result = subprocess.run(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ )
+ output = result.stdout
+ error = result.stderr
+ return output, error
+
+
+class SandboxedPythonCoder(ForgeActor):
+ """Monarch actor wrapper for _SandboxedPythonCoder.
+
+ This is a thin wrapper that makes the sandboxed Python coder available
+ as a distributed Monarch actor. All business logic is in _SandboxedPythonCoder.
+
+ Args:
+ docker_image: Docker image URL to import (e.g., "docker://python:3.10").
+ sqsh_image_path: Local filesystem path where the enroot .sqsh image will be stored.
+ container_name: Unique name for the enroot container instance.
+ """
+
+ def __init__(
+ self,
+ docker_image: str = "docker://python:3.10",
+ sqsh_image_path: str = "python-image.sqsh",
+ container_name: str = "sandbox",
+ ):
+ self._coder = _SandboxedPythonCoder(
+ docker_image=docker_image,
+ sqsh_image_path=sqsh_image_path,
+ container_name=container_name,
+ )
+
+ @endpoint
+ async def setup(self):
+ """Setup the sandboxed environment."""
+ return await self._coder.setup()
+
+ @endpoint
+ async def recreate(self):
+ """Recreate the container instance from the base image."""
+ return await self._coder.recreate()
+
+ @endpoint
+ async def execute(self, code: str) -> tuple[str, str]:
+ """Execute Python code inside the container.
+
+ Args:
+ code: Python source code string to execute.
+
+ Returns:
+ The captured stdout and stderr from the execution, as a
+ (stdout, stderr) tuple of strings.
+ """
+ return await self._coder.execute(code)
diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
new file mode 100644
index 000000000..4889f183d
--- /dev/null
+++ b/src/forge/actors/generator.py
@@ -0,0 +1,737 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import sys
+import time
+from collections.abc import Mapping
+from copy import copy
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+import torchstore as ts
+
+from forge.actors._torchstore_utils import (
+ extract_param_name,
+ get_dcp_whole_state_dict_key,
+ get_param_key,
+ get_param_prefix,
+ load_tensor_from_dcp,
+ rdma_available,
+)
+
+from forge.controller import (
+ ForgeActor,
+ get_proc_mesh,
+ host_mesh_from_proc,
+ stop_proc_mesh,
+)
+from forge.data_models.completion import Completion
+from forge.data_models.prompt import to_prompt
+from forge.observability.metrics import record_metric, Reduce
+from forge.observability.perf_tracker import Tracer
+from forge.types import ProcessConfig
+from forge.util._shared_tensor import SharedTensor, SharedTensorHandle
+from monarch.actor import current_rank, endpoint, ProcMesh, this_host
+
+from vllm.config import VllmConfig
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.entrypoints.utils import _validate_truncation_size
+from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import get_distributed_init_method
+from vllm.v1.core.kv_cache_utils import get_kv_cache_config
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequest
+from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.parallel_sampling import ParentRequest
+from vllm.v1.engine.processor import Processor
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request
+from vllm.v1.structured_output import StructuredOutputManager
+from vllm.worker.worker_base import WorkerWrapperBase
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+@dataclass
+class Generator(ForgeActor):
+ """Instance of a vLLM-based generator.
+
+ This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The
+ main difference is that all communications are controlled here via Monarch's proc meshes.
+
+ Args:
+ engine_args (EngineArgs): The engine arguments to use for the vLLM engine.
+ sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
+ use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on
+ whether or not RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled.
+
+ Example:
+ >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
+ ... engine_args=EngineArgs(...),
+ ... sampling_params=SamplingParams(...),
+ ... )
+ >>> await generator.generate("Tell me a joke")
+ Completion(prompt="Tell me a joke", text="A: Why did the chicken cross the road? B: To get to the other side.",
+ token_ids=[...], logprobs=[...])
+ >>> await generator.shutdown()
+ """
+
+ engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
+ sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
+ use_dcp_for_weight_sync: bool | None = None
+ prefetch_weights_to_shm: bool = True
+ n_fetcher_procs: int = 8
+
+ def __post_init__(self):
+ super().__init__()
+ self._run_task: asyncio.Task | None = None
+ self._generator_proc: ProcMesh | None = None
+ self._worker_procs: ProcMesh | None = None
+ self.worker: GeneratorWorker | None = None
+ self.running = False
+ self.generator_version: int = 0
+
+ if isinstance(self.engine_args, Mapping):
+ self.engine_args = EngineArgs(**self.engine_args)
+ self.engine_args._is_v1_supported_oracle = lambda *_: True
+ self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
+
+ if isinstance(self.sampling_params, Mapping):
+ self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
+ self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+
+ if self.use_dcp_for_weight_sync is None:
+ self.use_dcp_for_weight_sync = not rdma_available()
+ logger.debug(f"{self.use_dcp_for_weight_sync=}")
+
+ @endpoint
+ async def get_vllm_config(self) -> VllmConfig:
+ return self.vllm_config
+
+ @endpoint
+ async def register_worker(self, worker: GeneratorWorker) -> None:
+ self.worker = worker
+ logger.debug("Registered GeneratorWorker on Generator.")
+
+ @classmethod
+ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
+ cls: type["Generator"],
+ *args,
+ **kwargs,
+ ) -> "Generator":
+ """Custom launch for the Generator service with its workers.
+
+ We overwrite the default Service launch method in order to setup Actors (GeneratorWorker) within this "coordinating" Actor.
+ We first create a proc_mesh for the workers, then a proc_mesh for the generator, and then we spawn the workers
+ and the generator in setup.
+ """
+ # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
+ process_config: ProcessConfig = ProcessConfig(
+ procs=cls.procs,
+ hosts=cls.hosts,
+ with_gpus=cls.with_gpus,
+ mesh_name=cls.mesh_name,
+ )
+
+ # First, spawn the worker processes which may or may not be
+ # on remote hosts.
+ worker_procs = await get_proc_mesh(process_config=process_config)
+
+ # Then, grab a single host from the workers...
+ host_mesh = await host_mesh_from_proc(worker_procs)
+ singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
+ host_mesh = host_mesh.slice(**singleton_slice)
+
+ # We ask the provisioner for a single process on a single host
+ generator_proc_config = copy(process_config)
+ generator_proc_config.procs = 1
+ generator_proc_config.with_gpus = False
+ generator_proc = await get_proc_mesh(
+ process_config=generator_proc_config, host_mesh=host_mesh
+ )
+ # TODO - expand support so name can stick within kwargs
+ actor_name = kwargs.pop("name", cls.__name__)
+ generator = generator_proc.spawn(
+ actor_name,
+ cls,
+ *args,
+ **kwargs,
+ )
+
+ vllm_config = (
+ await generator.get_vllm_config.call_one()
+ ) # Config should be the same across all actors
+ worker = worker_procs.spawn(
+ "vllm_worker", GeneratorWorker, vllm_config=vllm_config
+ )
+ await worker.setup.call()
+ await generator.register_worker.call(worker)
+
+ generator._generator_proc = generator_proc
+ generator._worker_procs = worker_procs
+ await generator.setup.call()
+
+ return generator
+
+ @endpoint
+ async def setup(self):
+ """Mirrors the __init__ of vLLM's LLMEngine."""
+ self.request_id = 0
+ self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
+
+ # TODO: Investigate whether this can be combined with `generator.running`
+ self.accepting_requests = True
+
+ self.request_lock = asyncio.Condition() # Guard for accepting_requests
+ self.update_lock = asyncio.Condition() # Guard for updating requests
+
+ # Setup processors
+ # TODO: move all processing to the Environment
+ # TODO: add support for `log_stats` and `mm_registry`
+ tokenizer = init_tokenizer_from_configs(
+ model_config=self.vllm_config.model_config,
+ scheduler_config=self.vllm_config.scheduler_config,
+ lora_config=self.vllm_config.lora_config,
+ )
+ self.processor = Processor(
+ vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
+ )
+ self.output_processor = OutputProcessor(tokenizer, log_stats=None)
+
+ # Configure KV caches
+ kv_cache_configs = await self.worker.setup_kv_cache.call()
+ _, kv_cache_config = next(kv_cache_configs.items())
+ self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+ self.vllm_config.cache_config.num_cpu_blocks = 0
+
+ # Setup scheduler
+ # TODO: Add support for `log_stats`
+ structured_output_manager = StructuredOutputManager(self.vllm_config)
+ self.scheduler = Scheduler(
+ vllm_config=self.vllm_config,
+ kv_cache_config=kv_cache_config,
+ structured_output_manager=structured_output_manager,
+ include_finished_set=False,
+ log_stats=None,
+ )
+ self._start_processing()
+ if self.prefetch_weights_to_shm:
+ self._spawn_fetchers()
+
+ def _spawn_fetchers(self):
+ """Spawn weight fetchers that prefetch weights from torchstore to shared memory."""
+ # TODO: this assumes the generator is on the same host as the worker
+ # and only works for single host generators. Figure out how to support
+ # generators with workers spanned across multiple hosts.
+ fetcher_procs = this_host().spawn_procs(
+ per_host={"procs": self.n_fetcher_procs}
+ )
+ self._fetcher_procs = fetcher_procs
+ self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
+
+ def _start_processing(self):
+ if self._run_task is None or self._run_task.done():
+ self._run_task = asyncio.create_task(self.run())
+
+ async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
+ for handle in state_dict.values():
+ handle.drop()
+
+ async def _fetch_weights(
+ self,
+ version: int,
+ ) -> dict[str, SharedTensorHandle]:
+ """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
+ prefix = get_param_prefix(version)
+ matching_keys = await ts.keys(prefix)
+ hf_param_names = [extract_param_name(key) for key in matching_keys]
+
+ n_fetchers = self.weight_fetchers.size()
+
+ def split_keys(keys):
+ return [keys[i::n_fetchers] for i in range(n_fetchers)]
+
+ futures = []
+ for i, names in enumerate(split_keys(hf_param_names)):
+ fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
+ version=version, param_names=names
+ )
+ futures.append(fut)
+
+ sub_state_dicts = [await fut for fut in futures]
+
+ state_dict = {}
+ for sd in sub_state_dicts:
+ state_dict.update(sd)
+
+ return state_dict
+
+ @endpoint
+ async def generate(
+ self,
+ prompt: str,
+ *,
+ priority: int = 0,
+ sampling_params: SamplingParams | None = None,
+ ) -> list[Completion]:
+ """Generate a response for the given prompt
+
+ Args:
+ prompt (str): The prompt to generate a response for.
+ priority (int, optional): The priority of the request. Defaults to 0.
+ sampling_params (SamplingParams, optional): Sampling parameters to use for this request.
+ If not provided, uses self.sampling_params.
+
+ Returns:
+ list[Completion]: n completions from vLLM based on your prompt.
+ """
+ t = Tracer("generator_perf/generate", timer="gpu")
+ t.start()
+ record_metric("generator/generate/count_requests", 1, Reduce.SUM)
+
+ if sampling_params is not None:
+ # as in `post_init`
+ sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+
+ params = sampling_params or self.sampling_params
+
+ self.request_id += 1 % sys.maxsize
+ request_id = str(self.request_id)
+
+ tokenization_kwargs = {}
+ # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
+ truncate_prompt_tokens = params.truncate_prompt_tokens
+ _validate_truncation_size(
+ self.vllm_config.model_config.max_model_len,
+ truncate_prompt_tokens,
+ tokenization_kwargs,
+ )
+ prompt_str, request = self.processor.process_inputs(
+ request_id=request_id,
+ prompt={"prompt": prompt},
+ params=params,
+ arrival_time=None,
+ tokenization_kwargs=tokenization_kwargs,
+ trace_headers=None,
+ priority=priority,
+ data_parallel_rank=None, # We do not support DP
+ )
+ # Wait until we're accepting requests (releases lock while waiting)
+ # If accepting_requests is True, continue immediately (holding the lock)
+ # If False, release lock, wait for notification, re-acquire and recheck
+ async with self.request_lock:
+ await self.request_lock.wait_for(lambda: self.accepting_requests)
+
+ # Explicitly keeping the redundant logic to make it easier to pick up vLLM changes
+ if (num_samples := params.n) == 1:
+ self.output_processor.add_request(request, prompt_str, None, 0)
+ request, _ = self._preprocess_add_request(request)
+ request_fut = asyncio.Future()
+ self.requests[request_id] = (None, request_fut)
+ self.scheduler.add_request(request)
+ else:
+ parent_req = ParentRequest(request_id, params)
+ for idx in range(num_samples):
+ # Note: `get_child_info` mutates ParentRequest to track the
+ # generated child request
+ child_request_id, params_child = parent_req.get_child_info(idx)
+ child_request = request if idx == num_samples - 1 else copy(request)
+ child_request.request_id = child_request_id
+ child_request.sampling_params = params_child
+ self.output_processor.add_request(
+ child_request, prompt_str, parent_req, idx
+ )
+ child_request, _ = self._preprocess_add_request(child_request)
+ self.scheduler.add_request(child_request)
+ request_fut = asyncio.Future()
+ self.requests[request_id] = (parent_req, request_fut)
+
+ completions = await request_fut
+
+ # Log some metrics
+ record_metric(
+ "generator/generate/count_sequences_completed",
+ len(completions),
+ Reduce.SUM,
+ )
+
+ t.stop()
+ return completions
+
+ def _preprocess_add_request(
+ self, request: EngineCoreRequest
+ ) -> tuple[Request, int]:
+ """(forge/issues/332) Will require attention when we bump vllm versions
+ https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
+ """
+ if request.mm_hashes is not None:
+ raise NotImplementedError("Support for mm_hash is not implemented yet.")
+ req = Request.from_engine_core_request(request)
+ if req.use_structured_output:
+ self.scheduler.structured_output_manager.grammar_init(request)
+ return req, 0
+
+ async def run(self) -> None:
+ """Schedule, execute, and make output.
+ https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L276
+ """
+ # TODO: move postprocessing out of loop to not block
+ self.running = True
+ while self.running:
+ scheduler_output = self.scheduler.schedule()
+ worker_outputs = await self.worker.execute_model.call(scheduler_output)
+
+ # The results of `execute_model` are gathered on the driver rank (rank 0)
+ _, worker_output = next(worker_outputs.items())
+ outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
+ outputs = outputs.get(0) or EngineCoreOutputs()
+ await asyncio.sleep(0) # Release control before processing outputs
+
+ processed_outputs = self.output_processor.process_outputs(
+ outputs.outputs,
+ engine_core_timestamp=outputs.timestamp,
+ iteration_stats=None, # TODO: add support for `iteration_stats`
+ )
+ for request_output in processed_outputs.request_outputs:
+ if request_output.finished:
+ completions = self._to_completions(request_output)
+ _, fut = self.requests.pop(request_output.request_id)
+ fut.set_result(completions)
+
+ # Notify waiters if queue is drained
+ async with self.request_lock:
+ if len(self.requests) == 0:
+ self.request_lock.notify_all()
+
+ @endpoint
+ async def update_weights(self, version: int) -> None:
+ """Update weights on base model from a generator version to be found in a torchstore volume.
+
+ Args:
+ generator_version (int): Generator version from which to update. This will correspond to a key in a
+ torchstore volume.
+
+ Example:
+ >>> trainer.train_step(...)
+ >>> version += 1
+ >>> await trainer.push_weights()
+ >>> generator.update_weights(version)
+ """
+ # TODO: enable shared memory prefetch for DCP-based weight sync
+ if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
+ logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
+ fetch_fut = asyncio.create_task(self._fetch_weights(version))
+ else:
+ fetch_fut = None
+ # Serialize updates (only one update at a time)
+ async with self.update_lock:
+ # Grab the lock to stop accepting requests and wait on pending requests
+ async with self.request_lock:
+ self.accepting_requests = False
+ curr_requests = [fut for _, fut in self.requests.values()]
+
+ if curr_requests:
+ # Record pending requests count
+ record_metric(
+ "generator_perf/update_weights/sum_pending_gen_requests",
+ len(curr_requests),
+ Reduce.SUM,
+ )
+ logger.debug(f"Waiting for {len(curr_requests)} pending requests")
+
+ # Start timing the wait
+ wait_start = time.perf_counter()
+
+ # Wait until all pending requests have been processed
+ # TODO: If generating long sequences, this might be long and will block
+ # generator weight updates
+ await self.request_lock.wait_for(lambda: len(self.requests) == 0)
+
+ if curr_requests:
+ wait_duration = time.perf_counter() - wait_start
+ record_metric(
+ "generator_perf/update_weights/avg_waiting_for_generation_duration_s",
+ wait_duration,
+ Reduce.MEAN,
+ )
+
+ logger.debug(f"Starting weight update on {self.__class__.__name__}")
+
+ if fetch_fut is not None:
+ fetched_weights = await fetch_fut
+ # Call update_weights on every policy_worker
+ await self.worker.update_weights.call(
+ shared_memory_state_dict=fetched_weights
+ )
+ await self._drop_shared_memory(fetched_weights)
+ else:
+ await self.worker.update_weights.call(version=version)
+ self.generator_version = version
+
+ # After updating the weights, we need to reset the KV cache
+ self.scheduler.reset_prefix_cache()
+
+ # Resume accepting requests and wake up any waiting generate() calls
+ async with self.request_lock:
+ self.accepting_requests = True
+ self.request_lock.notify_all()
+
+ logger.info(f"Weight update completed (now v{self.generator_version})")
+
+ @endpoint
+ async def _reset_prefix_cache(self):
+ self.scheduler.reset_prefix_cache()
+
+ @endpoint
+ async def get_version(self) -> int:
+ """Get the current generator version."""
+ return self.generator_version
+
+ @endpoint
+ async def stop(self):
+ self.running = False
+
+ def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
+ """Convert a vLLM RequestOutput to a list of Completion objects."""
+ completions = []
+ original_prompt = request_output.prompt
+ prompt_token_ids = request_output.prompt_token_ids
+ for output in request_output.outputs:
+ completions.append(
+ Completion(
+ # TODO: the to_prompt encoding will be different from the original.
+ # This is okay for now, since I don't see any direct usage of prompt using completion object.
+ prompt=to_prompt(original_prompt),
+ stop_reason=output.finish_reason,
+ text=output.text,
+ prompt_ids=torch.tensor(prompt_token_ids),
+ token_ids=torch.tensor(output.token_ids),
+ logprobs=self._extract_logprobs(output),
+ generator_version=self.generator_version,
+ metadata={"num_cached_tokens": request_output.num_cached_tokens},
+ )
+ )
+ return completions
+
+ def _extract_logprobs(self, sample: CompletionOutput) -> torch.Tensor | None:
+ if sample.logprobs is not None:
+ return torch.tensor(
+ [
+ top_k_dict[token].logprob
+ for token, top_k_dict in zip(sample.token_ids, sample.logprobs)
+ ]
+ )
+ return None
+
+ @classmethod
+ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
+ cls: type["Generator"], actor: "Generator"
+ ):
+ assert (
+ actor._generator_proc is not None
+ ), "Tried to shutdown a generator that was not initialized correctly"
+ assert (
+ actor._worker_procs is not None
+ ), "Tried to shutdown a generator that was not initialized correctly"
+
+ # TODO - may want to expand stop to gracefully respond to
+ # ongoing requests.
+ await actor.stop.call()
+ await stop_proc_mesh(actor._worker_procs)
+ await stop_proc_mesh(actor._generator_proc)
+ await stop_proc_mesh(actor._fetcher_procs)
+
+ @endpoint
+ async def save_model_params(self):
+ """Save model parameters before weight update, used for testing purposes only."""
+ logger.info("[Generator] save model parameters for testing.")
+ await self.worker.save_model_params.call()
+
+ @endpoint
+ async def validate_model_params(self, validate_fn):
+ """Validate updated model params using validate_fn."""
+ logger.info("[Generator] start validating model parameters.")
+ return await self.worker.validate_model_params.call(validate_fn)
+
+
+@dataclass
+class GeneratorWorker(ForgeActor):
+ """Mirrors a vLLM GPUWorker
+ https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/worker/gpu_worker.py
+
+ In general, this class should not be instantiated or called directly. Rather, the Generator controls
+ the creation and invocation of all GeneratorWorker.
+ """
+
+ vllm_config: VllmConfig
+ # TODO: Remove below param
+ _test_prev_params = {}
+
+ def __post_init__(self):
+ super().__init__()
+
+ @endpoint
+ async def setup(self):
+ self.rank = current_rank().rank
+ os.environ["RANK"] = str(self.rank)
+ parallel_config = self.vllm_config.parallel_config
+ set_multiprocessing_worker_envs(parallel_config)
+ ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
+ distributed_init_method = get_distributed_init_method(ip, port)
+ all_kwargs = [{}] * parallel_config.world_size
+ local_rank = self.rank % torch.accelerator.device_count()
+ is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0
+ all_kwargs[self.rank] = {
+ "vllm_config": self.vllm_config,
+ "local_rank": local_rank,
+ "rank": self.rank,
+ "distributed_init_method": distributed_init_method,
+ "is_driver_worker": is_driver_worker,
+ }
+ self.worker = WorkerWrapperBase(self.vllm_config, self.rank)
+ self.worker.init_worker(all_kwargs)
+ self.worker.init_device()
+ self.worker.load_model()
+
+ @endpoint
+ async def setup_kv_cache(self) -> KVCacheConfig:
+ """https://github.com/vllm-project/vllm/blob/5c7fe25491825b95936c011a43337c7d4fb7e472/vllm/v1/engine/core.py#L199"""
+ kv_cache_spec = self.worker.get_kv_cache_spec()
+ if kv_cache_spec is not None:
+ available_gpu_memory = self.worker.determine_available_memory()
+ else:
+ # Attention free models don't need memory for kv cache
+ available_gpu_memory = 0
+
+ # Get the kv cache tensor size
+ kv_cache_config = get_kv_cache_config(
+ self.vllm_config, kv_cache_spec, available_gpu_memory
+ )
+ # TODO: unify configs across TorchStore
+ # unify_kv_cache_configs(kv_cache_configs)
+ self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+ self.vllm_config.cache_config.num_cpu_blocks = 0
+
+ # Initialize kv cache and warmup the execution:
+ # from multiproc_executor.py:MultiprocExecutor.initialize_from_config
+ kv_cache_configs = [None] * self.vllm_config.parallel_config.world_size
+ kv_cache_configs[self.rank] = kv_cache_config
+ self.worker.initialize_from_config(kv_cache_configs)
+ self.worker.compile_or_warm_up_model()
+ self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
+ return kv_cache_config
+
+ @endpoint
+ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
+ return self.worker.execute_model(schedule)
+
+ @endpoint
+ async def update_weights(
+ self,
+ version: Optional[int] = None,
+ *,
+ shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
+ ) -> None:
+ model = self.worker.model_runner.model
+ if shared_memory_state_dict is not None:
+ logger.info("[PolicyWorker] update weights from shared memory.")
+ loaded_weights = set()
+ for name, param_handle in shared_memory_state_dict.items():
+ # Use context manager for automatic cleanup
+ with param_handle.to_shared_tensor() as shared_tensor:
+ param = shared_tensor.tensor
+ loaded = model.load_weights([(name, param)])
+ del param
+ loaded_weights.update(loaded)
+ logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
+ return
+ # normal update_weights without shared memory prefetching
+ if version is None:
+ raise ValueError(
+ "version must be provided if not using shared_memory_state_dict"
+ )
+ logger.info("[PolicyWorker] update weights from torchstore.")
+ prefix = get_param_prefix(version)
+ matching_keys = await ts.keys(prefix)
+ dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
+ use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
+ loaded_weights = set()
+
+ if use_dcp_for_weight_sync:
+ dcp_handle = await ts.get(dcp_whole_state_dict_key)
+ hf_param_names = dcp_handle.param_names
+ for name in hf_param_names:
+ param = load_tensor_from_dcp(dcp_handle, name)
+ loaded = model.load_weights([(name, param)])
+ del param
+ loaded_weights.update(loaded)
+ else:
+ hf_param_names = [extract_param_name(key) for key in matching_keys]
+ # We can't pass a generator since vllm load_weights is not async.
+ # Instead, we just call load_weights with one parameter at a time.
+ for name in hf_param_names:
+ param_key = get_param_key(version, name)
+ param = await ts.get(param_key)
+ loaded = model.load_weights([(name, param)])
+ del param
+ loaded_weights.update(loaded)
+
+ @endpoint
+ async def save_model_params(self):
+ """Save model parameters before weight update, used for testing purposes only."""
+ logger.info("[GeneratorWorker] save model parameters for testing.")
+ for name, param in self.worker.model_runner.model.named_parameters():
+ self._test_prev_params[name] = param.detach().cpu()
+ logger.info(
+ "[GeneratorWorker] finished saving model parameters, len = %d",
+ len(self._test_prev_params),
+ )
+
+ @endpoint
+ async def validate_model_params(self, validate_fn):
+ """Validate updated model params using validate_fn."""
+ logger.info("[GeneratorWorker] start validating model parameters.")
+ return validate_fn(
+ self._test_prev_params, self.worker.model_runner.model, logger
+ )
+
+
+class _WeightFetcher(ForgeActor):
+ """Fetches weights from torchstore and loads them into shared memory.
+ This has to be colocated with the GeneratorWorker."""
+
+ @endpoint
+ async def fetch(
+ self,
+ *,
+ version: int,
+ param_names: list[str],
+ ) -> dict[str, SharedTensorHandle]:
+ """Fetch weights from torchstore and load them into shared memory."""
+ sd = {}
+ for name in param_names:
+ param_key = get_param_key(version, name)
+ param = await ts.get(param_key)
+ # Use context manager to ensure cleanup after getting handle
+ with SharedTensor(tensor=param) as shared_tensor:
+ handle = shared_tensor.get_handle()
+ sd[name] = handle
+ del param # Explicitly free the tensor after copying to shared memory
+ return sd
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
deleted file mode 100644
index 4b61f096c..000000000
--- a/src/forge/actors/policy.py
+++ /dev/null
@@ -1,773 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from __future__ import annotations
-
-import asyncio
-import logging
-import os
-import sys
-import time
-from collections.abc import Mapping
-from copy import copy
-from dataclasses import asdict, dataclass, field, fields
-
-import torch
-import torch.distributed.checkpoint as dcp
-import torchstore as ts
-from monarch.actor import current_rank, endpoint, ProcMesh
-from torchstore.state_dict_utils import DELIM
-from vllm.config import VllmConfig
-
-from vllm.engine.arg_utils import EngineArgs
-from vllm.entrypoints.utils import _validate_truncation_size
-from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
-from vllm.lora.request import LoRARequest
-from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.usage.usage_lib import UsageContext
-from vllm.utils import get_distributed_init_method
-from vllm.v1.core.kv_cache_utils import get_kv_cache_config
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequest
-from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import ParentRequest
-from vllm.v1.engine.processor import Processor
-from vllm.v1.request import Request
-from vllm.v1.structured_output import StructuredOutputManager
-from vllm.worker.worker_base import WorkerWrapperBase
-
-from forge.actors._torchstore_utils import (
- extract_param_name,
- get_dcp_whole_state_dict_key,
- get_param_key,
- get_param_prefix,
- load_tensor_from_dcp,
-)
-
-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-from forge.data.sharding import VLLMSharding
-from forge.data_models.completion import Completion
-from forge.data_models.prompt import to_prompt
-from forge.interfaces import Policy as PolicyInterface
-from forge.observability.metrics import record_metric, Reduce
-from forge.observability.perf_tracker import Tracer
-from forge.types import ProcessConfig
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-@dataclass
-class SamplingConfig:
- """
- Overrides for vLLMs sampling params.
-
- Note: We'll want to tie this closer to or directly use vllm's
- SamplingParams. It is currently used to track a supported
- subset
-
- Args:
- n: Number of samples to generate.
- guided_decoding: Whether to use guided decoding.
- max_tokens: Maximum number of tokens to generate.
- """
-
- n: int = 1
- guided_decoding: bool = False
- max_tokens: int = 512
- temperature: float = 1.0
- top_p: float = 1.0
- logprobs: int = 1
-
- def __post_init__(self):
- super().__init__()
- gd_params = None
- if self.guided_decoding:
- gd_params = GuidedDecodingParams(choice=["Positive", "Negative"])
- self.guided_decoding = gd_params
-
- @classmethod
- def from_dict(cls, d: Mapping):
- d = dict(d)
- all_fields = set(cls.__dataclass_fields__.keys())
- valid_args = {k: v for k, v in d.items() if k in all_fields}
- return cls(**valid_args)
-
-
-@dataclass
-class EngineConfig(EngineArgs):
- """
- EngineConfig extends EngineArgs with worker-specific fields.
- Overlapping keys in input dict will override EngineArgs defaults.
- """
-
- model: str = "meta-llama/Llama-3.1-8B-Instruct"
- tensor_parallel_size: int = 1
- pipeline_parallel_size: int = 1
- enforce_eager: bool = False
- enable_expert_parallel: bool = False
-
- # Original method returns False when not run in the main thread
- _is_v1_supported_oracle = lambda *_: True
-
- @classmethod
- def from_dict(cls, d: Mapping):
- d = dict(d)
- all_fields = [f.name for f in fields(cls)]
- valid_args = {k: v for k, v in d.items() if k in all_fields}
- return cls(**valid_args)
-
- def create_vllm_config(self) -> VllmConfig:
- """Converts the current EngineConfig into vLLM's vLLMConfig."""
- # Note: EngineArgs.create_engine_config
- # creates a VllmConfig
- return self.create_engine_config(UsageContext.LLM_CLASS)
-
-
-@dataclass
-class Policy(PolicyInterface):
- engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
- sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
- use_vllm_builtin_load: bool = True
- available_devices: str | None = None
- use_dcp: bool = True
- # Gets set up by setup
- sampling_params: SamplingParams | None = None
- lora_request: LoRARequest | None = None
- tokenization_kwargs: dict = field(default_factory=dict)
- policy_worker: "PolicyWorker" = None
- policy_version: int | None = None
-
- def __post_init__(self):
- super().__init__()
- self._run_task: asyncio.Task | None = None
- self._policy_proc: ProcMesh | None = None
- self._worker_procs: ProcMesh | None = None
- self.running = False
- if isinstance(self.engine_config, Mapping):
- self.engine_config = EngineConfig.from_dict(self.engine_config)
- if isinstance(self.sampling_config, Mapping):
- self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
- # No conversion needed for boolean flag
-
- @classmethod
- async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
- cls: type["Policy"],
- *,
- engine_config: EngineConfig | Mapping = EngineConfig(),
- sampling_config: SamplingConfig | Mapping = SamplingConfig(),
- available_devices: str | None = None,
- use_dcp: bool = True,
- **kwargs,
- ) -> "Policy":
- # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
- # automatically.
- process_config: ProcessConfig = ProcessConfig(
- procs=cls.procs,
- hosts=cls.hosts,
- with_gpus=cls.with_gpus,
- mesh_name=cls.mesh_name,
- )
- worker_procs = await get_proc_mesh(process_config=process_config)
-
- # TODO - issues/144 we will want to ensure colocation with workers
- # We're currently locating the Policy on the local host proc mesh
- # vLLM initialization without setting env variables at proc_mesh creation
- # level leads to issues.
- # Once we can create multiple proc meshes on a host mesh, we can ensure
- # host colocation
- policy_proc_config = copy(process_config)
- policy_proc_config.procs = 1
- policy_proc_config.hosts = None
- policy_proc_config.with_gpus = False
-
- policy_proc = await get_proc_mesh(process_config=policy_proc_config)
-
- if isinstance(engine_config, Mapping):
- engine_config = EngineConfig.from_dict(engine_config)
-
- vllm_config = engine_config.create_vllm_config()
- # TODO (felipemello): LocalFetcherActor doesnt spawn with this, so cannot
- # do logging within PolicyWorker
- workers = worker_procs.spawn(
- "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
- )
-
- if isinstance(sampling_config, Mapping):
- sampling_config = SamplingConfig(**sampling_config)
-
- # TODO - expand support so name can stick within kwargs
- actor_name = kwargs.pop("name", cls.__name__)
- policy = policy_proc.spawn(
- actor_name,
- cls,
- engine_config=engine_config,
- sampling_config=sampling_config,
- available_devices=available_devices,
- policy_worker=workers,
- **kwargs,
- )
- policy._policy_proc = policy_proc
- policy._worker_procs = worker_procs
- await policy.setup.call()
- return policy
-
- @classmethod
- async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
- cls: type["Policy"], actor: "Policy"
- ):
- assert (
- actor._policy_proc is not None
- ), "Tried to shutdown a policy that was not initialized correctly"
- assert (
- actor._worker_procs is not None
- ), "Tried to shutdown a policy that was not initialized correctly"
-
- # TODO - may want to expand stop to gracefully respond to
- # ongoing requests.
- await actor.stop.call()
- await stop_proc_mesh(actor._worker_procs)
- await stop_proc_mesh(actor._policy_proc)
-
- @endpoint
- async def setup(self):
- # Set up policy_worker
- assert self.policy_worker is not None, "Policy worker should not be None"
- await self.policy_worker.setup.call()
-
- self.request_id = 0
- self.policy_version = 0
- self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
-
- # TODO: Investigate whether this can be combined with `policy.running`
- # Whether this policy is accepting requests.
- self.accepting_requests = True
- # Guard for accepting_requests
- self.request_lock = asyncio.Condition()
- # Guard for updating requests
- self.update_lock = asyncio.Condition()
-
- self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()
-
- # Setup sampling params
- self.sampling_params = get_default_sampling_params(
- self.vllm_config, overrides=asdict(self.sampling_config)
- )
-
- # Setup processors
- # TODO: move all processing to the Environment
- # TODO: add support for `log_stats` and `mm_registry`
- tokenizer = init_tokenizer_from_configs(
- model_config=self.vllm_config.model_config,
- scheduler_config=self.vllm_config.scheduler_config,
- lora_config=self.vllm_config.lora_config,
- )
- self.processor = Processor(
- vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
- )
- self.output_processor = OutputProcessor(tokenizer, log_stats=None)
-
- # Setup scheduler
- # TODO: Add support for `log_stats`
- kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
- _, kv_cache_config = next(kv_cache_configs.items())
- self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
- self.vllm_config.cache_config.num_cpu_blocks = 0
-
- structured_output_manager = StructuredOutputManager(self.vllm_config)
- self.scheduler = Scheduler(
- vllm_config=self.vllm_config,
- kv_cache_config=kv_cache_config,
- structured_output_manager=structured_output_manager,
- include_finished_set=False,
- log_stats=None,
- )
- self.start_processing()
-
- def start_processing(self):
- """Start the replica's processing loop if not already running."""
- if self._run_task is None or self._run_task.done():
- self._run_task = asyncio.create_task(self.run())
-
- @endpoint
- async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
- """Generate a response for the given prompt
-
- Args:
- prompt (str): The prompt to generate a response for.
- priority (int, optional): The priority of the request. Defaults to 0.
-
- Returns:
- RequestOutput: vLLM class with the generated response.
- """
- t = Tracer("policy_perf/generate", timer="gpu")
- t.start()
-
- record_metric("policy/generate/count_requests", 1, Reduce.SUM)
-
- self.request_id += 1 % sys.maxsize
- request_id = str(self.request_id) # implement from a counter
-
- # Wraps prompt into a dict
- prompt_dict: dict[str, str] = convert_input(prompt=prompt)
-
- # truncate prmpt
- tokenization_kwargs = self.tokenization_kwargs or {}
- # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
- truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
- _validate_truncation_size(
- self.vllm_config.model_config.max_model_len,
- truncate_prompt_tokens,
- tokenization_kwargs,
- )
- t.step("prompt_truncation")
-
- # process and tokenize prompt
- prompt_str, request = self.processor.process_inputs(
- request_id=request_id,
- prompt=prompt_dict,
- params=self.sampling_params,
- arrival_time=None,
- lora_request=self.lora_request,
- tokenization_kwargs=tokenization_kwargs,
- trace_headers=None,
- priority=priority,
- data_parallel_rank=None,
- )
- t.step("process_inputs")
-
- # Wait until we're accepting requests (releases lock while waiting)
- # If accepting_requests is True, continue immediately (holding the lock)
- # If False, release lock, wait for notification, re-acquire and recheck
- async with self.request_lock:
- await self.request_lock.wait_for(lambda: self.accepting_requests)
-
- # Explicitly keeping the redundant logic to make it easier to pick up
- # vllm changes
- # TODO: Clean up before release
- if (num_samples := self.sampling_params.n) == 1:
- self.output_processor.add_request(request, prompt_str, None, 0)
- request, _ = self.preprocess_add_request(request)
- request_fut = asyncio.Future()
- self.requests[request_id] = (None, request_fut)
-
- self.scheduler.add_request(request)
- else:
- parent_req = ParentRequest(request_id, self.sampling_params)
- for idx in range(num_samples):
- # Note: `get_child_info` mutates ParentRequest to track the
- # generated child request
- child_request_id, params = parent_req.get_child_info(idx)
- child_request = request if idx == num_samples - 1 else copy(request)
- child_request.request_id = child_request_id
- child_request.sampling_params = params
- self.output_processor.add_request(
- child_request, prompt_str, parent_req, idx
- )
- child_request, _ = self.preprocess_add_request(child_request)
-
- self.scheduler.add_request(child_request)
- request_fut = asyncio.Future()
- self.requests[request_id] = (parent_req, request_fut)
-
- completions = await request_fut
- t.step("generate")
-
- record_metric(
- "policy/generate/count_sequences_completed",
- len(completions),
- Reduce.SUM,
- )
-
- for completion in completions:
- num_generated_tokens = len(completion.token_ids)
- record_metric(
- "policy/generate/sum_tokens_generated",
- num_generated_tokens,
- Reduce.SUM,
- )
-
- record_metric(
- "policy/generate/avg_tokens_generated",
- num_generated_tokens,
- Reduce.MEAN,
- )
-
- t.stop()
-
- return completions
-
- # Abstracted to match vllm
- # https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
- def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
- if request.mm_hashes is not None:
- raise NotImplementedError("Support for mm_hash is not implemented yet.")
- request: Request = Request.from_engine_core_request(request)
- if request.use_structured_output:
- self.scheduler.structured_output_manager.grammar_init(request)
-
- return request, 0 # Unused Arg: Current Wave
-
- async def run(self):
- # TODO: add support for `iteration_stats`
- # TODO: move postprocessing out of loop to not block
- self.running = True
- while self.running:
-
- scheduler_output = self.scheduler.schedule()
-
- worker_outputs = await self.policy_worker.execute_model.call(
- scheduler_output
- )
-
- # the results of `execute_model` is gathered on the driver rank (rank 0)
- _, worker_output = next(worker_outputs.items())
- outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
- outputs = outputs.get(0) or EngineCoreOutputs()
- await asyncio.sleep(0) # Release control before processing outputs
-
- processed_outputs = self.output_processor.process_outputs(
- outputs.outputs,
- engine_core_timestamp=outputs.timestamp,
- iteration_stats=None,
- )
-
- for request_output in processed_outputs.request_outputs:
- if request_output.finished:
- completions = self._to_completions(request_output)
- _, fut = self.requests.pop(request_output.request_id)
- fut.set_result(completions)
-
- # Notify waiters if queue is drained
- async with self.request_lock:
- if len(self.requests) == 0:
- self.request_lock.notify_all()
-
- @endpoint
- async def update_weights(self, policy_version: int):
- # Serialize updates (only one update at a time)
- async with self.update_lock:
- # Grab the lock to stop accepting requests and wait on pending requests
- async with self.request_lock:
- self.accepting_requests = False
-
- curr_requests = [fut for _, fut in self.requests.values()]
- if curr_requests:
- # Record pending requests metrics
- record_metric(
- "policy_perf/update_weights/avg_pending_requests",
- len(curr_requests),
- Reduce.MEAN,
- )
- record_metric(
- "policy_perf/update_weights/max_pending_requests",
- len(curr_requests),
- Reduce.MAX,
- )
- logger.debug(f"Waiting for {len(curr_requests)} pending requests")
-
- # Wait until all pending requests have been processed
- # TODO: If generating long sequences, this might be long and will block
- # policy weight updates
- await self.request_lock.wait_for(lambda: len(self.requests) == 0)
-
- # Record weight update metrics
- record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)
-
- logger.debug(f"Starting weight update on {self.__class__.__name__}")
- if self.use_vllm_builtin_load:
- await self.policy_worker.update.call(version=policy_version)
- else:
- await self.policy_worker.update_DEPRECATED.call(version=policy_version)
- self.policy_version = policy_version
-
- # After updating the weights, we need to reset the KV cache
- self.scheduler.kv_cache_manager.reset_prefix_cache()
-
- # Resume accepting requests and wake up any waiting generate() calls
- async with self.request_lock:
- self.accepting_requests = True
- self.request_lock.notify_all()
-
- logger.info(f"Weight update completed (now v{self.policy_version})")
-
- @endpoint
- async def update_weights_DEPRECATED(self, policy_version: int): # noqa: N802
- # TODO: If generating long sequences, this might be long and will block policy weight updates
- curr_requests = [fut for _, fut in self.requests.values()]
- if curr_requests:
- logger.debug(f"Waiting for {len(curr_requests)} pending requests")
- await asyncio.gather(*curr_requests)
-
- await self.policy_worker.update_DEPRECATED.call(version=policy_version)
- self.policy_version = policy_version
- logger.info(f"Weight update completed (now v{self.policy_version})")
-
- @endpoint
- async def get_version(self) -> int:
- """Get the current policy version."""
- return self.policy_version
-
- @endpoint
- async def stop(self):
- self.running = False
-
- @endpoint
- async def _test_save_model_params(self):
- """Save model parameters before weight update, used for tesing purposes only."""
- logger.info("[Policy] save model parameters for testing.")
- await self.policy_worker._test_save_model_params.call()
-
- @endpoint
- async def _test_validate_model_params(self, validate_fn):
- """Validate updated model params using validate_fn."""
- logger.info("[Policy] start validating model parameters.")
- return await self.policy_worker._test_validate_model_params.call(validate_fn)
-
- def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
- """Convert a RequestOutput to a list of Completion objects."""
- completions = []
- original_prompt = request_output.prompt
- prompt_token_ids = request_output.prompt_token_ids
- for output in request_output.outputs:
- completions.append(
- Completion(
- # TODO: the to_prompt encoding will be different from the original.
- # This is okay for now, since I don't see any direct usage of prompt using completion object.
- prompt=to_prompt(original_prompt),
- stop_reason=output.finish_reason,
- text=output.text,
- prompt_ids=torch.tensor(prompt_token_ids),
- token_ids=torch.tensor(output.token_ids),
- logprobs=self._extract_logprobs(output),
- generator_version=self.policy_version,
- )
- )
-
- return completions
-
- def _extract_logprobs(self, one_sample: CompletionOutput) -> torch.Tensor | None:
- """
- Extract log probabilities from a sample, if available.
- """
- if one_sample.logprobs is not None:
- return torch.tensor(
- [
- top_k_dict[token].logprob
- for token, top_k_dict in zip(
- one_sample.token_ids, one_sample.logprobs
- )
- ]
- )
- return None
-
-
-@dataclass
-class PolicyWorker(ForgeActor):
- vllm_config: VllmConfig
- state_dict_key: str = "model_state_dict"
- # TODO: remove this later since no plumbing exists to change this value.
- # Also, whether to use dcp or not can be inferred from torchstore get() call.
- use_dcp: bool = True
-
- # used for tesing purposes only
- _test_prev_params = {}
-
- def __post_init__(self):
- super().__init__()
-
- @endpoint
- async def setup(self):
- # TODO: remove ["gpus"] when monarch implements a flat rank
- self.rank = current_rank()["gpus"]
- self.worker = self.setup_worker()
-
- @endpoint
- async def execute_model(self, schedule: SchedulerOutput):
- return self.worker.execute_model(schedule)
-
- async def _load_tensor_parallel_state_dict(
- self, current_state_dict: dict, version: int
- ):
- """
- Load full state dict from torchstore into tensor parallel model with deterministic sharding.
- """
- sharding = VLLMSharding(
- self.vllm_config.parallel_config.tensor_parallel_size, self.rank
- )
-
- checkpoint_id = f"{self.state_dict_key}{DELIM}{version}"
- dcp_metadata = None
- if self.use_dcp:
- dcp_metadata = await ts.get(checkpoint_id)
-
- for param_name in current_state_dict.keys():
- current_tensor = current_state_dict[param_name]
-
- # Load the full tensor from torchstore
- # TODO: only get the part of the tensor that is needed
- if self.use_dcp:
- tensor_meta = dcp_metadata.state_dict_metadata[param_name]
- stored_tensor = torch.empty(
- size=tensor_meta.size, dtype=tensor_meta.properties.dtype
- )
- dcp.load(
- checkpoint_id=checkpoint_id, state_dict={param_name: stored_tensor}
- )
- else:
- stored_tensor = await ts.get(f"{checkpoint_id}{DELIM}{param_name}")
- sharding.load_from_source_to_target(
- param_name,
- stored_tensor,
- current_tensor,
- )
-
- @endpoint
- async def update_DEPRECATED(self, version: int): # noqa: N802
- """Update model weights by reading state dict from torchstore.
- Deprecated. This uses manual sharding logic which is buggy."""
- key = f"{self.state_dict_key}{DELIM}{version}"
- model = self.worker.model_runner.model
- current_state_dict = model.state_dict()
- start = time.perf_counter()
- await self._load_tensor_parallel_state_dict(current_state_dict, version)
- logger.info(
- f"Loaded state dict from {key} in {time.perf_counter() - start} seconds"
- )
-
- @endpoint
- async def update(self, version: int):
- """Update model weights by reading state dict from torchstore"""
- logger.info(
- f"[PolicyWorker::update] start updating weights to version {version}"
- )
- model = self.worker.model_runner.model
- prefix = get_param_prefix(version)
- logger.debug(f"{prefix=}")
- matching_keys = await ts.keys(prefix)
- logger.debug(f"{matching_keys=}")
- dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
- loaded_weights = set()
- t = Tracer("policy_worker_perf/update", timer="gpu")
- t.start()
- # Entire state dict is stored in a single DCP handle
- if dcp_whole_state_dict_key in matching_keys:
- logger.info(
- f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}"
- )
- dcp_handle = await ts.get(dcp_whole_state_dict_key)
- hf_param_names = dcp_handle.param_names
- for name in hf_param_names:
- param = load_tensor_from_dcp(dcp_handle, name)
- loaded = model.load_weights([(name, param)])
- del param
- loaded_weights.update(loaded)
- else: # Load each parameter from torchstore directly without DCP
- hf_param_names = [extract_param_name(key) for key in matching_keys]
- # We can't pass a generator since vllm load_weights is not async.
- # Instead, we just call load_weights with one parameter at a time.
- for name in hf_param_names:
- param_key = get_param_key(version, name)
- param = await ts.get(param_key)
- loaded = model.load_weights([(name, param)])
- del param
- loaded_weights.update(loaded)
- t.stop()
- logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}")
-
- @endpoint
- async def setup_kv_cache(self):
- """Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
- TODO: test that fails if vllm method updates
- """
- kv_cache_spec = self.worker.get_kv_cache_spec()
- if kv_cache_spec is not None:
- available_gpu_memory = self.worker.determine_available_memory()
- else:
- # Attention free models don't need memory for kv cache
- available_gpu_memory = 0
-
- # Get the kv cache tensor size
- kv_cache_config = get_kv_cache_config(
- self.vllm_config, kv_cache_spec, available_gpu_memory
- )
- # TODO: unify configs across TorchStore
- # unify_kv_cache_configs(kv_cache_configs)
- self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
- self.vllm_config.cache_config.num_cpu_blocks = 0
-
- # Initialize kv cache and warmup the execution:
- # from multiproc_executor.py:MultiprocExecutor.initialize_from_config
- kv_cache_configs = [None] * self.vllm_config.parallel_config.world_size
- kv_cache_configs[self.rank] = kv_cache_config
- self.worker.initialize_from_config(kv_cache_configs)
- self.worker.compile_or_warm_up_model()
- self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
- return kv_cache_config
-
- @endpoint
- async def _test_save_model_params(self):
- """Save model parameters before weight update, used for tesing purposes only."""
- logger.info("[PolicyWorker] save model parameters for testing.")
- for name, param in self.worker.model_runner.model.named_parameters():
- self._test_prev_params[name] = param.detach().cpu()
- logger.info(
- "[PolicyWorker] finished saving model parameters, len = %d",
- len(self._test_prev_params),
- )
-
- @endpoint
- async def _test_validate_model_params(self, validate_fn):
- """Validate updated model params using validate_fn."""
- logger.info("[PolicyWorker] start validating model parameters.")
- return validate_fn(
- self._test_prev_params, self.worker.model_runner.model, logger
- )
-
- def setup_worker(self):
- """Build and Instantiate vLLM worker"""
- parallel_config = self.vllm_config.parallel_config
- set_multiprocessing_worker_envs(parallel_config)
- ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
- distributed_init_method = get_distributed_init_method(ip, port)
- all_kwargs = [{}] * parallel_config.world_size
- local_rank = self.rank % torch.accelerator.device_count()
- is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0
- all_kwargs[self.rank] = {
- "vllm_config": self.vllm_config,
- "local_rank": local_rank,
- "rank": self.rank,
- "distributed_init_method": distributed_init_method,
- "is_driver_worker": is_driver_worker,
- }
- worker = WorkerWrapperBase(self.vllm_config, self.rank)
- worker.init_worker(all_kwargs)
- worker.init_device()
- worker.load_model()
- return worker
-
-
-def convert_input(prompt=None, prompt_token_ids=None) -> dict:
- assert (prompt is None) ^ (prompt_token_ids is None)
- if prompt is not None:
- return {"prompt": prompt}
- return {"prompt_token_ids": prompt_token_ids}
-
-
-def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
- default_params = vllm_config.model_config.get_diff_sampling_param()
- if overrides is not None:
- default_params |= overrides
- if default_params:
- params = SamplingParams.from_optional(**default_params)
- else:
- params = SamplingParams()
- # We only care about the final output
- params.output_kind = RequestOutputKind.FINAL_ONLY
- return params
diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
index cc57e5246..99e25f937 100644
--- a/src/forge/actors/reference_model.py
+++ b/src/forge/actors/reference_model.py
@@ -13,11 +13,18 @@
from dataclasses import dataclass, field, fields
import torch
+
+from forge.controller import ForgeActor
+from forge.observability.metrics import record_metric, Reduce
+from forge.observability.perf_tracker import Tracer
+from forge.util.ops import compute_logprobs
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.parallel import loss_parallel
from torchtitan.config.job_config import (
Checkpoint,
+ Comm,
Compile,
Model,
Parallelism,
@@ -26,22 +33,46 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
-from forge.controller import ForgeActor
-from forge.observability.metrics import record_metric, Reduce
-from forge.observability.perf_tracker import Tracer
-from forge.util.ops import compute_logprobs
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@dataclass
class ReferenceModel(ForgeActor):
+ """
+ A reference model actor for reinforcement learning (RL) training.
+
+ Based on TorchTitan's engine architecture, this actor provides a
+ frozen model that only runs forward passes without gradient
+ computation. It is typically used to maintain algorithmic
+ consistency in policy optimization methods such as GRPO
+ (Group Relative Policy Optimization) or PPO (Proximal Policy
+ Optimization), where it serves as a fixed reference point to
+ compute KL divergence penalties against the training policy.
+
+ The reference model is loaded from a checkpoint and runs in
+ evaluation mode with inference_mode enabled to optimize memory and
+ compute efficiency.
+
+ Attributes:
+
+ model (Model): Model configuration (architecture, vocab size,
+ etc.)
+ parallelism (Parallelism): Parallelism strategy configuration
+ (TP, PP, CP, DP)
+ checkpoint (Checkpoint): Checkpoint loading configuration
+ compile (Compile): Torch compilation settings
+ comm (Comm): Communication backend configuration
+ training (Training): Training-related settings (dtype, garbage
+ collection, etc.)
+ """
+
# Refer to titan JobConfig for enabling more ForgeEngine configuration
model: Model = field(default_factory=Model)
parallelism: Parallelism = field(default_factory=Parallelism)
checkpoint: Checkpoint = field(default_factory=Checkpoint)
compile: Compile = field(default_factory=Compile)
+ comm: Comm = field(default_factory=Comm)
training: Training = field(
default_factory=Training
) # Needed in order to set attrs like dtype, garbage collection freq, etc.
@@ -68,6 +99,10 @@ def __post_init__(self):
self.rank = current_rank().rank
self.size = math.prod(current_size().values())
+ self.compute_log_probs = compute_logprobs
+ if self.compile.enable:
+ self.compute_log_probs = torch.compile(self.compute_log_probs)
+
env = {
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
@@ -85,7 +120,11 @@ def __post_init__(self):
@endpoint
async def setup(self):
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
- self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+ engine_config = ForgeJobConfig(**engine_config)
+ engine_config.checkpoint.folder = (
+ "" # hardcode to empty to force load from initial_load_path
+ )
+ self.engine = ForgeEngine(engine_config)
self.engine.checkpointer.load()
self.model = self.engine.model_parts[0] # No pipeline parallelism yet
self.model.eval()
@@ -98,7 +137,7 @@ async def forward(
Args:
input_ids (torch.Tensor): input token ids with shape [group_size, req + res length].
max_req_tokens (int): maximum request length.
- return_logprobs (bool): whether to return og probabilities instead of raw logits.
+ return_logprobs (bool): whether to return log probabilities instead of raw logits.
return_logprobs flag significantly impacts the amount of data transferred to the caller:
- When False: Returns logits with shape [group_size, req + res_length, vocab_size].
@@ -110,21 +149,15 @@ async def forward(
"""
# Record reference model metrics
record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
- record_metric(
- "reference_perf/forward/avg_sequence_length",
- input_ids.shape[1],
- Reduce.MEAN,
- )
t = Tracer("reference_perf/forward", timer="gpu", track_memory=True)
t.start()
self.engine.gc_handler.run(self.step)
- t.step("garbage_collection")
model_parts = self.engine.model_parts
parallel_dims = self.engine.parallel_dims
input_ids = input_ids.to("cuda")
- t.step("to_device")
+
# optional_context_parallel_ctx = (
# dist_utils.create_context_parallel_ctx(
# cp_mesh=parallel_dims.world_mesh["cp"],
@@ -146,15 +179,23 @@ async def forward(
with torch.inference_mode():
logits = self.model(input_ids)
self.step += 1
- if isinstance(logits, DTensor):
- logits = logits.full_tensor()
- t.step("forward")
if not return_logprobs:
+ if isinstance(logits, DTensor):
+ logits = logits.full_tensor()
t.stop()
return logits
else:
- logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
- t.step("compute_logprobs")
+ response_tokens = input_ids[:, max_req_tokens:]
+ if parallel_dims.tp_enabled and isinstance(logits, DTensor):
+ with loss_parallel():
+ logprobs = self.compute_log_probs(logits, response_tokens)
+
+ # loss_parallel produces Replicated output - to_local() returns the full tensor
+ logprobs = logprobs.to_local()
+ else:
+ if isinstance(logits, DTensor):
+ logits = logits.full_tensor()
+ logprobs = self.compute_log_probs(logits, response_tokens)
t.stop()
return logprobs
diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py
index a92fd5501..f62f8d550 100644
--- a/src/forge/actors/replay_buffer.py
+++ b/src/forge/actors/replay_buffer.py
@@ -6,105 +6,138 @@
import logging
import random
+from collections import deque
from dataclasses import dataclass
+from operator import itemgetter
from typing import Any, Callable
-from monarch.actor import endpoint
-
from forge.controller import ForgeActor
from forge.observability.metrics import record_metric, Reduce
-from forge.observability.perf_tracker import trace
+
+from monarch.actor import endpoint
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
+@dataclass
+class BufferEntry:
+ data: "Episode"
+ sample_count: int = 0
+
+
+def age_evict(
+ buffer: deque, policy_version: int, max_samples: int = None, max_age: int = None
+) -> list[int]:
+ """Buffer eviction policy, remove old or over-sampled entries"""
+ indices = []
+ for i, entry in enumerate(buffer):
+ if max_age is not None and policy_version - entry.data.policy_version > max_age:
+ continue
+ if max_samples is not None and entry.sample_count >= max_samples:
+ continue
+ indices.append(i)
+ return indices
+
+
+def random_sample(buffer: deque, sample_size: int, policy_version: int) -> list[int]:
+ """Buffer random sampling policy"""
+ if sample_size > len(buffer):
+ return None
+ return random.sample(range(len(buffer)), k=sample_size)
+
+
@dataclass
class ReplayBuffer(ForgeActor):
"""Simple in-memory replay buffer implementation."""
batch_size: int
- max_policy_age: int
dp_size: int = 1
+ max_policy_age: int | None = None
+ max_buffer_size: int | None = None
+ max_resample_count: int | None = 0
seed: int | None = None
collate: Callable = lambda batch: batch
-
- def __post_init__(self):
- super().__init__()
+ eviction_policy: Callable = age_evict
+ sample_policy: Callable = random_sample
@endpoint
async def setup(self) -> None:
- self.buffer: list = []
+ self.buffer: deque = deque(maxlen=self.max_buffer_size)
if self.seed is None:
self.seed = random.randint(0, 2**32)
random.seed(self.seed)
- self.sampler = random.sample
@endpoint
async def add(self, episode: "Episode") -> None:
- self.buffer.append(episode)
+ self.buffer.append(BufferEntry(episode))
record_metric("buffer/add/count_episodes_added", 1, Reduce.SUM)
@endpoint
- @trace("buffer_perf/sample", track_memory=False)
async def sample(
- self, curr_policy_version: int, batch_size: int | None = None
+ self, curr_policy_version: int
) -> tuple[tuple[Any, ...], ...] | None:
"""Sample from the replay buffer.
Args:
curr_policy_version (int): The current policy version.
- batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size
- passed in at initialization.
Returns:
A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
"""
- # Record sample request metric
- record_metric("buffer/sample/count_sample_requests", 1, Reduce.SUM)
- bsz = batch_size if batch_size is not None else self.batch_size
- total_samples = self.dp_size * bsz
+ total_samples = self.dp_size * self.batch_size
- # Evict old episodes
+ # Evict episodes
self._evict(curr_policy_version)
- if total_samples > len(self.buffer):
- return None
-
- # Calculate buffer utilization
- utilization_pct = (
- (total_samples / len(self.buffer)) * 100 if len(self.buffer) > 0 else 0
- )
-
- record_metric(
- "buffer/sample/avg_buffer_utilization",
- len(self.buffer),
- Reduce.MEAN,
- )
-
- record_metric(
- "buffer/sample/avg_buffer_utilization_pct",
- utilization_pct,
- Reduce.MEAN,
- )
+ # Calculate metrics
+ if len(self.buffer) > 0:
+ record_metric(
+ "buffer/sample/demand_to_size_ratio",
+ total_samples / len(self.buffer),
+ Reduce.MEAN,
+ )
+ if self.max_buffer_size:
+ record_metric(
+ "buffer/sample/avg_buffer_utilization",
+ len(self.buffer) / self.max_buffer_size,
+ Reduce.MEAN,
+ )
# TODO: prefetch samples in advance
- idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
- # Pop episodes in descending order to avoid shifting issues
- popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
-
- # Reorder popped episodes to match the original random sample order
- sorted_idxs = sorted(idx_to_sample, reverse=True)
- idx_to_popped = dict(zip(sorted_idxs, popped))
- sampled_episodes = [idx_to_popped[i] for i in idx_to_sample]
-
+ sampled_indices = self.sample_policy(
+ self.buffer, total_samples, curr_policy_version
+ )
+ if sampled_indices is None:
+ return None
+ sampled_episodes = []
+ for entry in self._collect(sampled_indices):
+ entry.sample_count += 1
+ sampled_episodes.append(entry.data)
+
+ # Calculate and record policy age metrics for sampled episodes
+ sampled_policy_ages = [
+ curr_policy_version - ep.policy_version for ep in sampled_episodes
+ ]
+ if sampled_policy_ages:
+ record_metric(
+ "buffer/sample/avg_sampled_policy_age",
+ sum(sampled_policy_ages) / len(sampled_policy_ages),
+ Reduce.MEAN,
+ )
+ record_metric(
+ "buffer/sample/max_sampled_policy_age",
+ max(sampled_policy_ages),
+ Reduce.MAX,
+ )
# Reshape into (dp_size, bsz, ...)
reshaped_episodes = [
- sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz]
+ sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size]
for dp_idx in range(self.dp_size)
]
+ # Call the underlying collate function to collate the episodes into a batch
return self.collate(reshaped_episodes)
@endpoint
@@ -117,46 +150,53 @@ async def evict(self, curr_policy_version: int) -> None:
"""
self._evict(curr_policy_version)
- def _evict(self, curr_policy_version: int) -> None:
+ def _evict(self, curr_policy_version):
buffer_len_before_evict = len(self.buffer)
- self.buffer = [
- trajectory
- for trajectory in self.buffer
- if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
- ]
- buffer_len_after_evict = len(self.buffer)
-
- # Record evict metrics
- policy_staleness = [
- curr_policy_version - ep.policy_version for ep in self.buffer
- ]
- if policy_staleness:
- record_metric(
- "buffer/evict/avg_policy_staleness",
- sum(policy_staleness) / len(policy_staleness),
- Reduce.MEAN,
- )
- record_metric(
- "buffer/evict/max_policy_staleness",
- max(policy_staleness),
- Reduce.MAX,
- )
+ indices = self.eviction_policy(
+ self.buffer,
+ curr_policy_version,
+ self.max_resample_count + 1,
+ self.max_policy_age,
+ )
+ self.buffer = deque(self._collect(indices))
- # Record eviction metrics
- evicted_count = buffer_len_before_evict - buffer_len_after_evict
- if evicted_count > 0:
- record_metric(
- "buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM
- )
+ evicted_count = buffer_len_before_evict - len(self.buffer)
+ record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM)
logger.debug(
f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, "
- f"{evicted_count} episodes expired, {buffer_len_after_evict} episodes left"
+ f"{evicted_count} episodes expired, {len(self.buffer)} episodes left"
)
+ def _collect(self, indices: list[int]):
+ """Efficiently traverse deque and collect elements at each requested index"""
+ n = len(self.buffer)
+ if n == 0 or len(indices) == 0:
+ return []
+
+ # Normalize indices and store with their original order
+ indexed = [(pos, idx % n) for pos, idx in enumerate(indices)]
+ indexed.sort(key=itemgetter(1))
+
+ result = [None] * len(indices)
+ rotations = 0 # logical current index
+ total_rotation = 0 # total net rotation applied
+
+ for orig_pos, idx in indexed:
+ move = idx - rotations
+ self.buffer.rotate(-move)
+ total_rotation += move
+ rotations = idx
+ result[orig_pos] = self.buffer[0]
+
+ # Restore original deque orientation
+ self.buffer.rotate(total_rotation)
+
+ return result
+
@endpoint
async def _getitem(self, idx: int):
- return self.buffer[idx]
+ return self.buffer[idx].data
@endpoint
async def _numel(self) -> int:
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
deleted file mode 100644
index 4ffc63001..000000000
--- a/src/forge/actors/trainer.py
+++ /dev/null
@@ -1,521 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-import math
-import os
-import shutil
-
-import time
-from collections.abc import Mapping
-from dataclasses import dataclass, field, fields
-from typing import Callable
-
-import torch
-import torch.distributed.checkpoint as dcp
-import torchstore as ts
-
-from monarch.actor import current_rank, current_size, endpoint
-from torch import Tensor
-from torch.distributed.checkpoint._nested_dict import flatten_state_dict
-from torchstore.state_dict_utils import DELIM
-from torchtitan.config.job_config import (
- ActivationCheckpoint,
- Checkpoint,
- Comm,
- Compile,
- Job,
- LRScheduler,
- MemoryEstimation,
- Model,
- Optimizer,
- Parallelism,
- Quantize,
- Training,
-)
-from torchtitan.experiments.forge.engine import ForgeEngine
-from torchtitan.experiments.forge.job_config import ForgeJobConfig
-
-from forge.actors._torchstore_utils import (
- DcpHandle,
- get_dcp_whole_state_dict_key,
- get_param_key,
-)
-
-from forge.controller import ForgeActor
-from forge.data.utils import batch_to_device
-from forge.observability.metrics import record_metric, Reduce
-from forge.observability.perf_tracker import Tracer
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
-def cleanup_old_weight_versions(
- state_dict_key: str,
- delim: str,
- current_policy_version: int,
-) -> None:
- """Delete old weight versions, keeping only current and N-1 versions.
-
- TODO - issues/194: provide a more robust way to handle eviction.
-
- Args:
- state_dict_key: The base key for state dict storage
- delim: The delimiter used between key and version
- current_policy_version: The current policy version to keep
- """
- if current_policy_version <= 1:
- return # No cleanup needed for versions 0 or 1
-
- prefix = f"{state_dict_key}{delim}"
- current_weights = f"{prefix}{current_policy_version}"
- previous_weights = f"{prefix}{current_policy_version - 1}"
-
- # Find all weight directories that match our pattern
- parent_dir = os.path.dirname(prefix) or "."
- if os.path.exists(parent_dir):
- for item in os.listdir(parent_dir):
- item_path = os.path.join(parent_dir, item)
- if (
- item.startswith(os.path.basename(prefix))
- and item != os.path.basename(current_weights)
- and item != os.path.basename(previous_weights)
- and os.path.isdir(item_path)
- ):
- try:
- shutil.rmtree(item_path, ignore_errors=True)
- logger.debug(f"Removed old weights at {item_path}")
- except OSError as e:
- logger.debug(f"Error deleting {item_path}: {e}")
-
-
-@dataclass
-class RLTrainer(ForgeActor):
- job: Job = field(default_factory=Job)
- model: Model = field(default_factory=Model)
- optimizer: Optimizer = field(default_factory=Optimizer)
- lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
- training: Training = field(default_factory=Training)
- parallelism: Parallelism = field(default_factory=Parallelism)
- checkpoint: Checkpoint = field(default_factory=Checkpoint)
- activation_checkpoint: ActivationCheckpoint = field(
- default_factory=ActivationCheckpoint
- )
- compile: Compile = field(default_factory=Compile)
- quantize: Quantize = field(default_factory=Quantize)
- comm: Comm = field(default_factory=Comm)
- memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
- # Non JobConfig-related fields
- loss: Callable = lambda logits, **targets: logits
- state_dict_key: str = "model_state_dict"
- use_dcp: bool = True
- dcp_path: str = "forge_dcp_tmp"
- vllm_tp_DEPRECATED: int = 1 # noqa: N815
- use_vllm_builtin_load: bool = True
-
- def __post_init__(self):
- """Initializes config types and env variables.
-
- torchrun normally hands env variables, but we need to do it ourselves
- in monarch for now.
-
- """
- super().__init__()
-
- if self.use_dcp:
- # DCP specific optimization
- torch.serialization.set_crc32_options(False)
-
- # Instantiate dict fields
- for f in fields(self):
- attr = getattr(self, f.name)
- if isinstance(attr, Mapping):
- setattr(self, f.name, f.type(**attr))
- elif not isinstance(attr, f.type):
- raise TypeError(
- f"{f.name} should be a {f.type} type or a dict like object"
- )
-
- self.step = 1 # fragile contract.
- self.num_training_steps = self.training.steps
- self.gradient_accumulation_steps = 1
- self.rank = current_rank().rank
- self.size = math.prod(current_size().values())
-
- env = {
- "RANK": str(self.rank),
- "LOCAL_RANK": str(self.rank),
- "LOCAL_WORLD_SIZE": str(self.size),
- "GROUP_RANK": str(self.size),
- "GROUP_WORLD_SIZE": str(self.size),
- "ROLE_RANK": str(self.rank),
- "ROLE_WORLD_SIZE": str(self.size),
- "ROLE_NAME": "rank",
- "WORLD_SIZE": str(self.size),
- "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
- }
- os.environ.update(env)
-
- @endpoint
- async def setup(self):
- # TODO: update ForgeEngine to not use ForgeJobConfig
- engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
- for key in {
- "loss",
- "state_dict_key",
- "use_dcp",
- "use_vllm_builtin_load",
- "dcp_path",
- "vllm_tp_DEPRECATED",
- }:
- engine_config.pop(key) # Not part of job config
- self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
- self.engine.checkpointer.load(step=self.step)
- self.engine.optimizers.zero_grad()
-
- def forward_backward(
- self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
- ) -> Tensor:
- model_parts = self.engine.model_parts
- parallel_dims = self.engine.parallel_dims
-
- # apply context parallelism if cp is enabled
- # ensure CP handles the separate freqs_cis buffer for each pp stage
- # if getattr(self.engine.model_args, "use_flex_attn", False):
- # cp_mesh = (
- # parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
- # )
- # init_attention_mask(
- # inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
- # )
-
- # optional_context_parallel_ctx = (
- # dist_utils.create_context_parallel_ctx(
- # cp_mesh=parallel_dims.world_mesh["cp"],
- # cp_buffers=[inputs, targets] + [m.freqs_cis for m in model_parts],
- # cp_seq_dims=[1, 1] + [0 for _ in model_parts],
- # cp_no_restore_buffers={inputs, targets},
- # cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
- # )
- # if parallel_dims.cp_enabled
- # else None
- # )
- optional_context_parallel_ctx = None
-
- if parallel_dims.pp_enabled:
- raise NotImplementedError("PP not implemented yet")
- # TODO implement PP
- # # Pipeline Parallel forward / backward inside step() call
- # with self.train_context(optional_context_parallel_ctx):
- # targets, losses = (
- # (labels, []) if self.pp_has_last_stage else (None, None)
- # )
- # if self.pp_has_first_stage:
- # self.pp_schedule.step(
- # inputs, target=targets, losses=losses, input_batch=inputs
- # )
- # else:
- # self.pp_schedule.step(
- # target=targets, losses=losses, input_batch=inputs
- # )
- #
- # # accumulate losses across pipeline microbatches
- # # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
- # loss = (
- # torch.mean(torch.stack(losses)).to(self.device)
- # if self.pp_has_last_stage
- # else torch.tensor([-1.0], device=self.device)
- # )
- else:
- # Non-PP forward / backward
- with self.engine.train_context(optional_context_parallel_ctx):
- assert len(model_parts) == 1
- with self.engine.maybe_enable_amp:
- logits = model_parts[0](**inputs)
- loss = self.loss(logits, **targets)
- # need to free to before bwd to avoid peaking memory
- del logits
- loss.backward()
-
- return loss
-
- @endpoint
- async def train_step(
- self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
- ) -> float:
-
- # Log timesteps
- t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
- t.start()
-
- self.engine.gc_handler.run(self.step)
- local_inputs = inputs[self.engine.dp_rank]
- local_targets = targets[self.engine.dp_rank]
- batch_to_device(local_inputs, self.engine.device)
- batch_to_device(local_targets, self.engine.device)
- # compute policy logprobs
- # TODO implement gradient accumulation
- # with GradientAccumulation(
- # self.gradient_accumulation_steps,
- # self.model,
- # self.data_parallel_size,
- # ) as grad_acc:
- loss = self.forward_backward(local_inputs, local_targets)
- torch.distributed.all_reduce(loss)
- t.step("forward_backward")
-
- # Get learning rate from scheduler
- current_lr = (
- self.engine.lr_schedulers.get_last_lr()[0]
- if hasattr(self.engine.lr_schedulers, "get_last_lr")
- else 0.001
- )
- record_metric("rl_trainer/learning_rate", current_lr, Reduce.MIN)
-
- self.engine.optimizers.step()
- self.engine.optimizers.zero_grad()
- self.engine.lr_schedulers.step()
- t.step("optimizer_step")
-
- # Record training metrics
- # TODO: delete item() to avoid cpu-gpu sync
- loss = loss.detach().cpu().item()
- record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM)
- record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN)
-
- # TODO: Extract actual KL divergence and policy entropy from the loss computation
- # These are placeholder values until the loss function exposes these metrics
- # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
- # record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD)
- # record_metric("rl_trainer/step/avg_policy_entropy", 0.0, Reduce.MEAN)
-
- self.step += 1
- self.engine.checkpointer.save(
- curr_step=self.step,
- last_step=self.step == self.num_training_steps,
- )
- t.step("save_checkpoint")
- t.stop()
- return loss
-
- @endpoint
- async def push_weights_DEPRECATED( # noqa: N802
- self, policy_version: int, vllm_tp_DEPRECATED: int = 1
- ) -> None: # noqa: N802
- """[Deprecated] This method pushes weights to torchstore in the vllm format,
- which is buggy and not scalable to other models.
- Deprecated in favor of push_weights."""
- return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)
-
- async def _push_weights_DEPRECATED( # noqa: N802
- self, policy_version: int, vllm_tp_DEPRECATED: int
- ) -> None: # noqa: N802
- # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
- # TODO:
- # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
- # May need to replicate the same in this code path.
- # 2. Unify CheckpointManager and TorchStore weights save control path.
- if "model" not in self.engine.checkpointer.states:
- raise RuntimeError("Model state not found in checkpointer state")
-
- sd = self.engine.checkpointer.states["model"].state_dict()
- flattened_state_dict, _ = flatten_state_dict(sd)
-
- if self.engine.checkpointer.sd_adapter is None:
- raise RuntimeError(
- "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
- )
- hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
-
- # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
- vllm_ready_hf_sd = _qwen3_hf_to_vllm(
- sd=hf_state_dict,
- num_layers=self.engine.model_args.n_layers,
- vllm_tp=vllm_tp_DEPRECATED,
- )
-
- key = f"{self.state_dict_key}{DELIM}{policy_version}"
- if self.use_dcp:
- # TODO - DCP should probably be being saved to NFS explicitly?
- # Right now it will only save everything locally
- storage_writer = torch.distributed.checkpoint.FileSystemWriter(
- key, single_file_per_rank=False, thread_count=8
- )
- metadata = dcp.save(
- storage_writer=storage_writer, state_dict=vllm_ready_hf_sd
- )
- await ts.put(key, metadata)
-
- # Delete old weight versions if they exist
- if self.rank == 0:
- cleanup_old_weight_versions(
- state_dict_key=self.state_dict_key,
- delim=DELIM,
- current_policy_version=policy_version,
- )
- else:
- await ts.put_state_dict(vllm_ready_hf_sd, key)
-
- @endpoint
- async def push_weights(self, policy_version: int) -> None:
- """Push weights to torchstore in HF format."""
- t = Tracer("rl_trainer_perf/push_weights", timer="gpu", track_memory=True)
- t.start()
- logger.info(f"Pushing weights for policy version {policy_version}")
- if not self.use_vllm_builtin_load:
- result = await self._push_weights_DEPRECATED(
- policy_version, self.vllm_tp_DEPRECATED
- )
- t.step("push_weights_DEPRECATED")
- return result
-
- start_time = time.perf_counter()
- if "model" not in self.engine.checkpointer.states:
- raise RuntimeError("Model state not found in checkpointer state")
-
- sd = self.engine.checkpointer.states["model"].state_dict()
- flattened_state_dict, _ = flatten_state_dict(sd)
- t.step("flatten_state_dict")
- if self.engine.checkpointer.sd_adapter is None:
- raise RuntimeError(
- "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
- )
- hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
- t.step("to_hf")
- if self.use_dcp:
- key = get_dcp_whole_state_dict_key(policy_version)
- dcp_id = f"{self.dcp_path}/{key}"
- storage_writer = torch.distributed.checkpoint.FileSystemWriter(
- dcp_id, single_file_per_rank=False, thread_count=8
- )
- metadata = dcp.save(storage_writer=storage_writer, state_dict=hf_state_dict)
- dcp_handle = DcpHandle(
- checkpoint_id=dcp_id,
- metadata=metadata,
- param_names=hf_state_dict.keys(),
- )
- await ts.put(key, dcp_handle)
- t.step("dcp_save")
- else:
- for name, param in hf_state_dict.items():
- key = get_param_key(policy_version, name)
- await ts.put(key, param)
- t.step("ts_save")
- t.stop()
- end_time = time.perf_counter()
- logger.info("Completed weights push in %.2f seconds", end_time - start_time)
-
- @endpoint
- async def cleanup(self) -> None:
- if self.engine.checkpointer:
- self.engine.checkpointer.close()
-
-
-def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor:
- """Shard and concatenate tensors along a given dimension.
-
- Args:
- source (list[torch.Tensor]): List of tensors to shard and concatenate.
- dim (int): Dimension along which to shard and concatenate.
- tp (int): Number of tensor parallel groups.
-
- Returns:
- torch.Tensor: Concatenated tensor.
- """
- sharded_sources = []
- for source in sources:
- sharded_sources.append(torch.chunk(source, tp, dim=dim))
-
- combined_shards = []
- for shard_idx in range(tp):
- combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim)
- combined_shards.append(combined)
- return torch.cat(combined_shards, dim=dim)
-
-
-def _qwen3_hf_to_vllm(
- sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int
-) -> dict[str, torch.Tensor]:
- """Convert transformers state dict to vLLM format. Specifically, this fuses
- QKV projection and MLP gate_up_proj layers.
-
- Args:
- sd (dict): State dict from HF model.
- num_layers (int): Number of layers in the model.
-
- Returns:
- dict: State dict in vLLM format.
- """
- load_sd = {}
-
- def unwrap(t):
- """Unwrap a DTensor to a Tensor."""
- return t.full_tensor() if isinstance(t, torch.distributed.tensor.DTensor) else t
-
- for key in sd.keys():
- sd[key] = unwrap(sd[key]).cpu()
-
- # Copy over directly mapped keys
- for k in sd:
- if any(
- x in k
- for x in [
- "down_proj",
- "input_layernorm",
- "post_attention_layernorm",
- "o_proj",
- "norm.weight",
- "embed_tokens.weight",
- "lm_head.weight",
- ]
- ):
- load_sd[k] = sd[k]
-
- for i in range(num_layers):
- prefix = f"model.layers.{i}."
- # QKV fusion
- q = sd[prefix + "self_attn.q_proj.weight"]
- k = sd[prefix + "self_attn.k_proj.weight"]
- v = sd[prefix + "self_attn.v_proj.weight"]
-
- load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat(
- [q, k, v], dim=0, tp=vllm_tp
- )
-
- # Untested: QKV fusion - handle bias if present
- q_bias_key = prefix + "self_attn.q_proj.bias"
- k_bias_key = prefix + "self_attn.k_proj.bias"
- v_bias_key = prefix + "self_attn.v_proj.bias"
-
- if all(key in sd for key in [q_bias_key, k_bias_key, v_bias_key]):
- q_bias = sd[q_bias_key]
- k_bias = sd[k_bias_key]
- v_bias = sd[v_bias_key]
- load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat(
- [q_bias, k_bias, v_bias], dim=0, tp=vllm_tp
- )
-
- # MLP gate_up_proj fusion
- gate = sd[prefix + "mlp.gate_proj.weight"]
- up = sd[prefix + "mlp.up_proj.weight"]
- load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat(
- [gate, up], dim=0, tp=vllm_tp
- )
-
- # Untested: MLP gate_up_proj fusion - handle bias if present
- gate_bias_key = prefix + "mlp.gate_proj.bias"
- up_bias_key = prefix + "mlp.up_proj.bias"
-
- if all(key in sd for key in [gate_bias_key, up_bias_key]):
- gate_bias = sd[gate_bias_key]
- up_bias = sd[up_bias_key]
- # Same sharding has to happen here
- load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat(
- [gate_bias, up_bias], dim=0, tp=vllm_tp
- )
-
- return load_sd
diff --git a/src/forge/actors/trainer/__init__.py b/src/forge/actors/trainer/__init__.py
new file mode 100644
index 000000000..8978ab76d
--- /dev/null
+++ b/src/forge/actors/trainer/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+
+from .titan import TitanTrainer
+
+__all__ = ["TitanTrainer", "RLTrainer"]
+
+
+def __getattr__(name):
+ if name == "RLTrainer":
+ warnings.warn(
+ "RLTrainer is deprecated and will be removed in a future version. "
+ "Please use TitanTrainer instead.",
+ FutureWarning,
+ stacklevel=2,
+ )
+ return TitanTrainer
+ raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/forge/actors/trainer/titan.py b/src/forge/actors/trainer/titan.py
new file mode 100644
index 000000000..6a4fa3cb4
--- /dev/null
+++ b/src/forge/actors/trainer/titan.py
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+
+import time
+from collections.abc import Mapping
+from dataclasses import dataclass, field, fields
+from typing import Callable
+
+import torch
+import torch.distributed.checkpoint as dcp
+import torchstore as ts
+
+from forge.actors._torchstore_utils import (
+ DcpHandle,
+ get_dcp_whole_state_dict_key,
+ get_param_key,
+ rdma_available,
+)
+
+from forge.controller import ForgeActor
+from forge.data.utils import batch_to_device
+from forge.observability.metrics import record_metric, Reduce
+from forge.observability.perf_tracker import Tracer
+
+from monarch.actor import endpoint
+from torch import Tensor
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torchtitan.config.job_config import (
+ ActivationCheckpoint,
+ Checkpoint,
+ Comm,
+ Compile,
+ Job,
+ LRScheduler,
+ MemoryEstimation,
+ Model,
+ Optimizer,
+ Parallelism,
+ Quantize,
+ Training,
+)
+from torchtitan.experiments.forge.engine import ForgeEngine
+from torchtitan.experiments.forge.job_config import ForgeJobConfig
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+@dataclass
+class TitanTrainer(ForgeActor):
+ """A generic trainer actor implementation built on top of TorchTitan.
+
+ Built on top of TorchTitan's training engine, this actor provides a complete training
+ loop for reinforcement learning. It performs forward and backward passes with gradient
+ computation, optimization steps, and checkpoint management. Unlike the ReferenceModel
+ actor which only runs forward passes, RLTrainer actively updates the policy model
+ parameters through gradient descent.
+
+ The trainer supports the same distributed training strategies that TorchTitan does,
+ including but not limited to, tensor parallelism, data parallelism, and FSDP
+ (Fully Sharded Data Parallel). It is typically used in conjunction with ReferenceModel
+ for policy optimization algorithms like GRPO (Group Relative Policy Optimization),
+ where it optimizes the policy against a loss that includes KL divergence penalties
+ from the reference model.
+
+ The trainer handles:
+ - Forward and backward propagation with automatic mixed precision (AMP)
+ - Optimizer steps with learning rate scheduling
+ """
+
+ job: Job = field(default_factory=Job)
+ model: Model = field(default_factory=Model)
+ optimizer: Optimizer = field(default_factory=Optimizer)
+ lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
+ training: Training = field(default_factory=Training)
+ parallelism: Parallelism = field(default_factory=Parallelism)
+ checkpoint: Checkpoint = field(default_factory=Checkpoint)
+ activation_checkpoint: ActivationCheckpoint = field(
+ default_factory=ActivationCheckpoint
+ )
+ compile: Compile = field(default_factory=Compile)
+ quantize: Quantize = field(default_factory=Quantize)
+ comm: Comm = field(default_factory=Comm)
+ memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
+ # Non JobConfig-related fields
+ loss: Callable = lambda logits, **targets: logits
+ state_dict_key: str = "model_state_dict"
+ use_dcp: bool = not rdma_available()
+ dcp_path: str = "forge_dcp_tmp"
+
+ def __post_init__(self):
+ super().__init__()
+ if self.use_dcp:
+ torch.serialization.set_crc32_options(False)
+
+ for f in fields(self):
+ attr = getattr(self, f.name)
+ if isinstance(attr, Mapping):
+ setattr(self, f.name, f.type(**attr))
+ elif not isinstance(attr, f.type):
+ raise TypeError(
+ f"{f.name} should be a {f.type} type or a dict like object"
+ )
+
+ self.step = 1 # fragile contract.
+ self.num_training_steps = self.training.steps
+ self.gradient_accumulation_steps = 1
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+ logger.info("Compiling loss")
+ self.loss = torch.compile(self.loss)
+
+ @endpoint
+ async def setup(self):
+ # TODO: update ForgeEngine to not use ForgeJobConfig
+ engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
+ for key in {
+ "loss",
+ "state_dict_key",
+ "use_dcp",
+ "dcp_path",
+ }:
+ engine_config.pop(key) # Not part of job config
+ self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+ self.engine.checkpointer.load(step=self.step)
+ self.engine.optimizers.zero_grad()
+
+ def forward_backward(
+ self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
+ ) -> Tensor:
+ model_parts = self.engine.model_parts
+ parallel_dims = self.engine.parallel_dims
+ optional_context_parallel_ctx = None
+ if parallel_dims.pp_enabled:
+ raise NotImplementedError("PP not implemented yet")
+ else:
+ with self.engine.train_context(optional_context_parallel_ctx):
+ assert len(model_parts) == 1
+ with self.engine.maybe_enable_amp:
+ logits = model_parts[0](**inputs)
+ loss = self.loss(logits, **targets)
+ del logits # Free to before bwd to avoid peaking memory
+ loss.backward()
+ return loss
+
+ @endpoint
+ async def train_step(
+ self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
+ ) -> float:
+ t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
+ t.start()
+
+ self.engine.gc_handler.run(self.step)
+ local_inputs = inputs[self.engine.dp_rank]
+ local_targets = targets[self.engine.dp_rank]
+ batch_to_device(local_inputs, self.engine.device)
+ batch_to_device(local_targets, self.engine.device)
+
+ loss = self.forward_backward(local_inputs, local_targets)
+ torch.distributed.all_reduce(loss)
+
+ t.step("forward_backward")
+
+ current_lr = self.engine.lr_schedulers.schedulers[0].get_last_lr()[0]
+ record_metric("rl_trainer/learning_rate", current_lr, Reduce.MIN)
+
+ self.engine.optimizers.step()
+ self.engine.optimizers.zero_grad()
+ self.engine.lr_schedulers.step()
+ t.step("optimizer_step")
+
+ # TODO: delete item() to avoid cpu-gpu sync
+ loss = loss.detach().item()
+ record_metric("rl_trainer/loss", loss, Reduce.MEAN)
+
+ # These are placeholder values until the loss function exposes these metrics
+ # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
+ # record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD)
+ # record_metric("rl_trainer/step/avg_policy_entropy", 0.0, Reduce.MEAN)
+
+ self.step += 1
+ self.engine.checkpointer.save(
+ curr_step=self.step,
+ last_step=self.step == self.num_training_steps,
+ )
+ t.step("save_checkpoint")
+ t.stop()
+ return loss
+
+ @endpoint
+ async def push_weights(self, policy_version: int) -> None:
+ """Push weights to torchstore in HF format."""
+ logger.info(f"Pushing weights for policy version {policy_version}")
+
+ start_time = time.perf_counter()
+ if "model" not in self.engine.checkpointer.states:
+ raise RuntimeError("Model state not found in checkpointer state")
+
+ sd = self.engine.checkpointer.states["model"].state_dict()
+ flattened_state_dict, _ = flatten_state_dict(sd)
+ if self.engine.checkpointer.sd_adapter is None:
+ raise RuntimeError(
+ "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
+ )
+ hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
+ if self.use_dcp:
+ key = get_dcp_whole_state_dict_key(policy_version)
+ dcp_id = f"{self.dcp_path}/{key}"
+ storage_writer = torch.distributed.checkpoint.FileSystemWriter(
+ dcp_id, single_file_per_rank=False, thread_count=8
+ )
+ metadata = dcp.save(storage_writer=storage_writer, state_dict=hf_state_dict)
+ dcp_handle = DcpHandle(
+ checkpoint_id=dcp_id,
+ metadata=metadata,
+ param_names=hf_state_dict.keys(),
+ )
+ await ts.put(key, dcp_handle)
+ else:
+ for name, param in hf_state_dict.items():
+ key = get_param_key(policy_version, name)
+ await ts.put(key, param)
+ end_time = time.perf_counter()
+ logger.info("Completed weights push in %.2f seconds", end_time - start_time)
+
+ @endpoint
+ async def cleanup(self) -> None:
+ if self.engine.checkpointer:
+ self.engine.checkpointer.close()
diff --git a/src/forge/api/__init__.py b/src/forge/api/__init__.py
new file mode 100644
index 000000000..b9aba00af
--- /dev/null
+++ b/src/forge/api/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Forge public API module.
+
+This module defines the public interfaces that all Forge implementations conform to.
+"""
+
+from forge.api.trainer import Trainer
+from forge.api.types import (
+ ForwardBackwardResult,
+ LossFn,
+ OptimStepResult,
+ ParallelismConfig,
+ TextTrainBatch,
+ TrainerConfig,
+ TrainerStatus,
+)
+
+__all__ = [
+ "Trainer",
+ "TextTrainBatch",
+ "ForwardBackwardResult",
+ "OptimStepResult",
+ "TrainerConfig",
+ "TrainerStatus",
+ "ParallelismConfig",
+ "LossFn",
+]
diff --git a/src/forge/api/trainer.py b/src/forge/api/trainer.py
new file mode 100644
index 000000000..16d307bbf
--- /dev/null
+++ b/src/forge/api/trainer.py
@@ -0,0 +1,298 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Protocol, runtime_checkable
+
+import torch
+
+from forge.api.types import (
+ ForwardBackwardResult,
+ LossFn,
+ OptimStepResult,
+ TextTrainBatch,
+ TrainerConfig,
+ TrainerStatus,
+)
+
+
+@runtime_checkable
+class Trainer(Protocol):
+ """Protocol defining the standard interface for all Forge trainers.
+
+ Trainer implementations are expected to accept a default loss function at
+ initialization time. This loss function is used when loss_fn is not
+ provided to forward_backward(). The default loss should follow the
+ LossFn signature.
+ """
+
+ async def forward_backward(
+ self, batch: TextTrainBatch, loss_fn: LossFn | None = None
+ ) -> ForwardBackwardResult:
+ """Execute forward pass and backward pass for one batch of data.
+
+ Basic usage - single batch per optimizer step:
+ >>> batch = TextTrainBatch(
+ >>> input_ids=torch.tensor([[1, 2, 3, 4, 5]]),
+ >>> target_ids=torch.tensor([[2, 3, 4, 5, 6]]),
+ >>> )
+ >>> result = await trainer.forward_backward(batch)
+ >>> await trainer.optim_step() # Apply gradients
+
+ To accumulate gradients over multiple batches before optimizer step:
+ >>> await trainer.forward_backward(batch1) # Accumulates
+ >>> await trainer.forward_backward(batch2) # Accumulates another batch
+ >>> await trainer.optim_step() # Apply all accumulated gradients
+
+ Custom loss function for specific batches:
+ >>> def custom_loss(outputs: dict[str, Any], batch: TextTrainBatch) -> torch.Tensor:
+ >>> # Custom loss computation (e.g., PPO clip, DPO, cut cross entropy, etc.)
+ >>> logits = outputs["logits"]
+ >>> # ... compute loss from logits, or use other outputs like hidden_states
+ >>> return loss
+ >>>
+ >>> result = await trainer.forward_backward(batch, loss_fn=custom_loss)
+
+ Args:
+ batch: TextTrainBatch containing input_ids, target_ids, and optional
+ target_mask/target_weights. See forge.api.types.TextTrainBatch for details.
+ loss_fn: Optional custom loss function. If None, uses the loss function
+ configured at trainer creation. Signature: (outputs, batch) -> loss
+ where outputs is a dict with at least "logits" key.
+ Useful for mixed training objectives or experimentation.
+
+ Returns:
+ ForwardBackwardResult containing loss and metrics
+
+ Note:
+ The default loss function is configured at trainer creation time via the
+ `loss` parameter. The `loss_fn` parameter here allows per-batch override.
+ All loss functions should accept (outputs: dict[str, Any], batch: TextTrainBatch)
+ where outputs contains at minimum a "logits" key.
+ """
+ ...
+
+ async def optim_step(self) -> OptimStepResult:
+ """Apply optimizer step using accumulated gradients, then clear gradients.
+
+ This method:
+ 1. Applies accumulated gradients via the optimizer
+ 2. Steps the learning rate scheduler
+ 3. Clears all gradients (zero_grad)
+ 4. Increments the training step counter
+ 5. May trigger automatic checkpointing (implementation-dependent)
+
+ Gradients must have been accumulated via forward_backward() calls before
+ calling this method.
+
+ Returns:
+ OptimStepResult containing step number, learning rate, and accumulated batch count
+
+ Example:
+ >>> # Accumulate over 4 batches
+ >>> for batch in batches[:4]:
+ >>> await trainer.forward_backward(batch)
+ >>> result = await trainer.optim_step()
+ >>> result.step
+ 1000
+ >>> result.learning_rate
+ 0.0001
+ >>> result.accumulated_microbatches
+ 4
+ """
+ ...
+
+ async def clear_gradients(self) -> None:
+ """Clear accumulated gradients without applying them.
+
+ Use this when you need to discard accumulated gradients without performing
+ an optimizer step. Common scenarios:
+ - Exception during gradient accumulation
+ - Skipping a training step due to some condition
+ - Recovering from OOM or other errors
+
+ This is equivalent to calling optimizer.zero_grad() and resetting internal
+ accumulation counters.
+
+ Example - Error recovery:
+ >>> try:
+ >>> for batch in batches:
+ >>> await trainer.forward_backward(batch)
+ >>> await trainer.optim_step()
+ >>> except torch.cuda.OutOfMemoryError:
+ >>> await trainer.clear_gradients() # Discard partial gradients
+ >>> # Retry with smaller batches
+
+ Example - Conditional skip:
+ >>> await trainer.forward_backward(batch)
+ >>> if should_skip_step():
+ >>> await trainer.clear_gradients() # Don't apply these gradients
+ >>> else:
+ >>> await trainer.optim_step()
+ """
+ ...
+
+ async def forward(self, inputs: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Run forward pass only, without backward pass (for evaluation/inference).
+
+ This method executes the model's forward pass without computing gradients.
+ Useful for:
+ - Evaluation on validation/test data
+ - Getting model predictions/logits
+ - Debugging model outputs
+
+ Args:
+ inputs: Dictionary containing model inputs. Typically includes:
+ - input_ids: torch.Tensor [batch_size, seq_len]
+ Other keys depend on the model architecture.
+
+ Returns:
+ Model output logits. Shape: [batch_size, seq_len, vocab_size]
+
+ Note:
+ This runs in torch.no_grad() context - no gradients are computed.
+
+ Example:
+ >>> eval_batch = {"input_ids": torch.tensor([[1, 2, 3, 4]])}
+ >>> logits = await trainer.forward(eval_batch) # [1, 4, vocab_size]
+ >>> predictions = logits.argmax(dim=-1) # [1, 4]
+ """
+ ...
+
+ async def save(
+ self,
+ name: str | None = None,
+ path: str | None = None,
+ weights_only: bool = False,
+ ) -> str:
+ """Save trainer state or weights to persistent storage.
+
+ By default, saves complete training state (model weights, optimizer state,
+ learning rate scheduler state, and step counter). Set weights_only=True to
+ save only model weights for inference/deployment.
+
+ Args:
+ name: Optional checkpoint name/identifier. If None, uses the current
+ step number (e.g., "step-1000" or "weights-step-1000").
+ path: Optional base directory or URI where checkpoint should be saved.
+ If None, uses the default checkpoint directory configured at trainer
+ creation. Supports different backends via URI schemes:
+ - `/local/path` - local filesystem
+ - `ts://key` - TorchStore
+ - `s3://bucket/key` - S3
+ weights_only: If True, saves only model weights (lighter, for inference).
+ If False (default), saves full training state including optimizer.
+
+
+ Returns:
+ Full path/URI where checkpoint was saved
+
+ Example:
+ >>> # Save full training state (default)
+ >>> path = await trainer.save(name="checkpoint-1000")
+ >>> path
+ "/default/checkpoint-1000"
+ >>>
+ >>> # Save weights only for inference
+ >>> path = await trainer.save(name="policy-v1", weights_only=True)
+ >>> path
+ "/default/policy-v1"
+ >>>
+ >>> # Save to TorchStore
+ >>> path = await trainer.save(name="best", path="ts://checkpoints")
+ >>> path
+ "ts://checkpoints/best"
+ """
+ ...
+
+ async def load(self, path: str | None = None) -> str:
+ """Load a previously saved checkpoint.
+
+ Restores training state from a checkpoint. Automatically handles both
+ full checkpoints and weights-only checkpoints.
+
+ Args:
+ path: Optional path or URI to the checkpoint to load. If None, loads
+ the most recent checkpoint from the default directory. Can be:
+ - `/local/path/checkpoint` - local filesystem
+ - `ts://key` - TorchStore
+ - `s3://bucket/key` - S3
+
+ Returns:
+ Path/URI that was loaded
+
+ Example:
+ >>> # Load latest checkpoint from default location
+ >>> path = await trainer.load()
+ >>> path
+ "/default/step-5000"
+ >>>
+ >>> # Load specific checkpoint by path
+ >>> path = await trainer.load("/checkpoints/step-5000")
+ >>> path
+ "/checkpoints/step-5000"
+ >>>
+ >>> # Load from TorchStore
+ >>> path = await trainer.load("ts://checkpoint-key")
+ >>> path
+ "ts://checkpoint-key"
+ """
+ ...
+
+ async def get_config(self) -> TrainerConfig:
+ """Get static trainer and model configuration.
+
+ Returns configuration information that doesn't change during training.
+ For runtime state like current step, use get_status() instead.
+
+ Returns:
+ TrainerConfig containing model name, model_config, and parallelism settings
+
+ Example:
+ >>> config = await trainer.get_config()
+ >>> config.model_name
+ "Qwen/Qwen2.5-7B"
+ >>> config.model_config["vocab_size"]
+ 151936
+ >>> config.parallelism.dp_degree
+ 4
+ >>> config.parallelism.device
+ "cuda:0"
+ """
+ ...
+
+ async def get_status(self) -> TrainerStatus:
+ """Get current runtime status of the trainer.
+
+ Returns dynamic information about the trainer's current state that changes
+ during training.
+
+ Returns:
+ TrainerStatus containing current step and accumulated batch count
+
+ Example:
+ >>> status = await trainer.get_status()
+ >>> status.step
+ 1000
+ >>> status.accumulated_microbatches
+ 2
+ """
+ ...
+
+ async def get_tokenizer(self):
+ """Get the tokenizer associated with this model.
+
+ Returns the tokenizer used for encoding/decoding text with this model.
+ Useful for preprocessing inputs or decoding model outputs.
+
+ Returns:
+ PreTrainedTokenizer: The HuggingFace tokenizer for this model
+
+ Example:
+ >>> tokenizer = await trainer.get_tokenizer()
+ >>> tokens = tokenizer.encode("Hello world")
+ >>> text = tokenizer.decode([1, 2, 3, 4])
+ """
+ ...
diff --git a/src/forge/api/types.py b/src/forge/api/types.py
new file mode 100644
index 000000000..c8a5c6cec
--- /dev/null
+++ b/src/forge/api/types.py
@@ -0,0 +1,196 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Type definitions for the Forge API."""
+
+from dataclasses import dataclass
+from typing import Any, Callable, TypeAlias
+
+import torch
+
+
+# Loss function signature: takes model outputs (as dict) and batch, returns scalar loss
+# The dict will typically contain logits, but may include other keys depending on use case.
+LossFn: TypeAlias = Callable[[dict[str, Any], "TextTrainBatch"], torch.Tensor]
+
+
+@dataclass
+class TextTrainBatch:
+ """A batch of text training data for forward_backward.
+
+ This dataclass defines the standard format for text training batches across all
+ Forge text trainers.
+
+ Attributes:
+ input_ids: Input token IDs. Shape: [batch_size, seq_len]
+ target_ids: Target token IDs for loss computation. Shape: [batch_size, seq_len]
+ target_mask: Mask indicating which tokens to compute loss on.
+ Shape: [batch_size, seq_len]. Values are 0 (ignore) or 1 (compute loss).
+ If None, computes loss on all tokens.
+ target_weights: Per-token weights for loss computation.
+ Shape: [batch_size, seq_len]. Used for importance weighting, such as
+ advantages in RL (GRPO, PPO) or custom loss weighting schemes.
+ If None, all tokens have weight 1.0.
+
+ Example:
+ >>> batch = TextTrainBatch(
+ >>> input_ids=torch.tensor([[1, 2, 3, 4, 5]]),
+ >>> target_ids=torch.tensor([[2, 3, 4, 5, 6]]),
+ >>> target_mask=torch.tensor([[0, 0, 1, 1, 1]]), # Only predict last 3 tokens
+ >>> target_weights=torch.tensor([[0, 0, 1.0, 0.8, 1.2]]), # Weight by advantage
+ >>> )
+ >>> result = await trainer.forward_backward(batch)
+ """
+
+ input_ids: torch.Tensor
+ target_ids: torch.Tensor
+ target_mask: torch.Tensor | None = None
+ target_weights: torch.Tensor | None = None
+
+
+@dataclass
+class ForwardBackwardResult:
+ """Result from a forward_backward pass.
+
+ Attributes:
+ loss: Loss value computed for the batch
+ metrics: Additional metrics computed during training (e.g., perplexity,
+ accuracy, KL divergence). May be empty if no additional metrics are tracked.
+ Values can be scalars, tensors, or other structured data depending on the loss.
+
+ Example:
+ >>> result = await trainer.forward_backward(batch)
+ >>> result.loss
+ 0.3542
+ >>> result.metrics
+ {"perplexity": 1.42, "kl_divergence": 0.05}
+ """
+
+ loss: float
+ metrics: dict[str, Any]
+
+
+@dataclass
+class OptimStepResult:
+ """Result from an optimizer step.
+
+ Attributes:
+ step: Training step number after this optimizer step
+ learning_rate: Current learning rate used for this step
+ accumulated_microbatches: Number of forward_backward calls that were
+ accumulated before this optimizer step. Useful for tracking gradient
+ accumulation behavior.
+
+ Example:
+ >>> result = await trainer.optim_step()
+ >>> result.step
+ 1000
+ >>> result.learning_rate
+ 0.0001
+ >>> result.accumulated_microbatches
+ 4
+ """
+
+ step: int
+ learning_rate: float
+ accumulated_microbatches: int
+
+
+@dataclass
+class ParallelismConfig:
+ """Parallelism configuration for distributed training.
+
+ Attributes:
+ dp_degree: Data parallel degree (number of data parallel replicas)
+ tp_degree: Tensor parallel degree (model sharding across devices)
+ pp_degree: Pipeline parallel degree (model sharding across pipeline stages)
+ cp_degree: Context parallel degree (sequence parallelism for long contexts)
+ ep_degree: Expert parallel degree (for MoE models)
+ world_size: Total number of processes in the distributed training job
+ dp_rank: Current data parallel rank (0 to dp_degree-1)
+ tp_rank: Current tensor parallel rank (0 to tp_degree-1)
+ device: Device identifier (e.g., "cuda:0", "cuda:1")
+
+ Example:
+ >>> config = await trainer.get_config()
+ >>> config.parallelism.dp_degree
+ 4
+ >>> config.parallelism.tp_degree
+ 2
+ >>> config.parallelism.pp_degree
+ 1
+ >>> config.parallelism.cp_degree
+ 1
+ >>> config.parallelism.ep_degree
+ 1
+ >>> config.parallelism.device
+ "cuda:0"
+ """
+
+ dp_degree: int
+ tp_degree: int
+ pp_degree: int
+ cp_degree: int
+ ep_degree: int
+ world_size: int
+ dp_rank: int
+ tp_rank: int
+ device: str
+
+
+@dataclass
+class TrainerConfig:
+ """Static trainer and model configuration.
+
+ This contains configuration information that doesn't change during training.
+
+ Attributes:
+ model_name: Name or path of the model being trained
+ model_config: Model architecture configuration. Common keys include:
+ - vocab_size: int - Size of the vocabulary
+ - hidden_size: int - Hidden dimension size
+ - num_layers: int - Number of transformer layers
+ - num_attention_heads: int - Number of attention heads
+ - max_seq_len: int - Maximum sequence length
+ parallelism: Parallelism configuration for distributed training
+
+ Example:
+ >>> config = await trainer.get_config()
+ >>> config.model_name
+ "Qwen/Qwen2.5-7B"
+ >>> config.model_config["vocab_size"]
+ 151936
+ >>> config.parallelism.dp_degree
+ 4
+ """
+
+ model_name: str
+ model_config: dict[str, Any]
+ parallelism: ParallelismConfig
+
+
+@dataclass
+class TrainerStatus:
+ """Runtime status of the trainer.
+
+ This contains dynamic information about the trainer's current state that
+ changes during training.
+
+ Attributes:
+ step: Current training step
+ accumulated_microbatches: Number of batches accumulated since the last
+ optim_step. Will be 0 if gradients were just applied/cleared.
+
+ Example:
+ >>> status = await trainer.get_status()
+ >>> status.step
+ 1000
+ >>> status.accumulated_microbatches
+ 2
+ """
+
+ step: int
+ accumulated_microbatches: int
diff --git a/src/forge/cli/download.py b/src/forge/cli/download.py
deleted file mode 100644
index 69ebde9aa..000000000
--- a/src/forge/cli/download.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-
-import json
-import os
-import textwrap
-import traceback
-
-from pathlib import Path
-
-from huggingface_hub import snapshot_download
-from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
-
-from forge.cli.subcommand import Subcommand
-
-# TODO: update this
-REPO_ID_FNAME = "original_repo_id"
-
-
-class Download(Subcommand):
- """Holds all the logic for the `forge download` subcommand."""
-
- def __init__(self, subparsers: argparse._SubParsersAction):
- super().__init__()
- self._parser = subparsers.add_parser(
- "download",
- prog="forge download",
- usage="forge download [OPTIONS]",
- help="Download a model from the Hugging Face Hub.",
- description="Download a model from the Hugging Face Hub.",
- epilog=textwrap.dedent(
- """\
- examples:
- # Download a model from the Hugging Face Hub with a Hugging Face API token
- $ forge download meta-llama/Llama-2-7b-hf --hf-token
- Successfully downloaded model repo and wrote to the following locations:
- /tmp/Llama-2-7b-hf/config.json
- /tmp/Llama-2-7b-hf/README.md
- /tmp/Llama-2-7b-hf/consolidated.00.pth
- ...
-
- # Download an ungated model from the Hugging Face Hub
- $ forge download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/model
- Successfully downloaded model repo and wrote to the following locations:
- /tmp/model/config.json
- /tmp/model/README.md
- /tmp/model/model-00001-of-00002.bin
- ...
-
- For a list of all models, visit the Hugging Face Hub
- https://huggingface.co/models.
- """
- ),
- formatter_class=argparse.RawTextHelpFormatter,
- )
- self._add_arguments()
- self._parser.set_defaults(func=self._download_cmd)
-
- def _add_arguments(self) -> None:
- """Add arguments to the parser."""
- self._parser.add_argument(
- "repo_id",
- type=str,
- help="Name of the repository on Hugging Face Hub.",
- )
- self._parser.add_argument(
- "--output-dir",
- type=Path,
- required=False,
- default=None,
- help="Directory in which to save the model. Defaults to `/tmp/`.",
- )
- self._parser.add_argument(
- "--hf-token",
- type=str,
- required=False,
- default=os.getenv("HF_TOKEN", None),
- help="Hugging Face API token. Needed for gated models like Llama2.",
- )
- self._parser.add_argument(
- "--ignore-patterns",
- type=str,
- required=False,
- help="If provided, files matching any of the patterns are not downloaded. Example: '*.safetensors'. "
- "Only supported for Hugging Face Hub models.",
- )
-
- def _download_cmd(self, args: argparse.Namespace) -> None:
- return self._download_from_huggingface(args)
-
- def _download_from_huggingface(self, args: argparse.Namespace) -> None:
- """Downloads a model from the Hugging Face Hub."""
- # Download the tokenizer and PyTorch model files
-
- # Default output_dir is `/tmp/`
- output_dir = args.output_dir
- if output_dir is None:
- model_name = args.repo_id.split("/")[-1]
- output_dir = Path("/tmp") / model_name
-
- print(f"Ignoring files matching the following patterns: {args.ignore_patterns}")
- try:
- true_output_dir = snapshot_download(
- args.repo_id,
- local_dir=output_dir,
- ignore_patterns=args.ignore_patterns,
- token=args.hf_token,
- )
- except GatedRepoError:
- if args.hf_token:
- self._parser.error(
- "It looks like you are trying to access a gated repository. Please ensure you "
- "have access to the repository."
- )
- else:
- self._parser.error(
- "It looks like you are trying to access a gated repository. Please ensure you "
- "have access to the repository and have provided the proper Hugging Face API token "
- "using the option `--hf-token` or by running `huggingface-cli login`."
- "You can find your token by visiting https://huggingface.co/settings/tokens"
- )
- except RepositoryNotFoundError:
- self._parser.error(
- f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
- )
- except Exception as e:
- tb = traceback.format_exc()
- msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}"
- self._parser.error(msg)
-
- # save the repo_id. This is necessary because the download step is a separate command
- # from the rest of the CLI. When saving a model adapter, we have to add the repo_id
- # to the adapter config.
- # TODO: this needs to be updated when we start using HF cache
- file_path = os.path.join(true_output_dir, REPO_ID_FNAME + ".json")
- with open(file_path, "w") as json_file:
- json.dump({"repo_id": args.repo_id}, json_file, indent=4)
-
- print(
- "Successfully downloaded model repo and wrote to the following locations:",
- *list(Path(true_output_dir).iterdir()),
- sep="\n",
- )
diff --git a/src/forge/cli/forge.py b/src/forge/cli/forge.py
deleted file mode 100644
index 7e5d2ac73..000000000
--- a/src/forge/cli/forge.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-
-from forge.cli.download import Download
-from forge.cli.run import Run
-
-
-class ForgeCLIParser:
- """Holds all information related to running the CLI"""
-
- def __init__(self):
- # Initialize the top-level parser
- self._parser = argparse.ArgumentParser(
- prog="forge",
- description="Welcome to the torchforge CLI!",
- add_help=True,
- )
- # Default command is to print help
- self._parser.set_defaults(func=lambda args: self._parser.print_help())
-
- # Add subcommands
- subparsers = self._parser.add_subparsers(title="subcommands")
- Download.create(subparsers)
- Run.create(subparsers)
-
- def parse_args(self) -> argparse.Namespace:
- """Parse CLI arguments"""
- return self._parser.parse_args()
-
- def run(self, args: argparse.Namespace) -> None:
- """Execute CLI"""
- args.func(args)
-
-
-def main():
- parser = ForgeCLIParser()
- args = parser.parse_args()
- parser.run(args)
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/forge/cli/run.py b/src/forge/cli/run.py
deleted file mode 100644
index 4a556c1f8..000000000
--- a/src/forge/cli/run.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-import os
-import sys
-import textwrap
-
-from pathlib import Path
-
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.run import get_args_parser as get_torchrun_args_parser, run
-
-import forge
-from forge.cli.subcommand import Subcommand
-
-ROOT = Path(forge.__file__).parent.parent
-
-
-class Run(Subcommand):
- """Holds all the logic for the `forge run` subcommand."""
-
- def __init__(self, subparsers):
- super().__init__()
- self._parser = subparsers.add_parser(
- "run",
- prog="forge run",
- help="Run a recipe. For distributed recipes, this supports all torchrun arguments.",
- description="Run a recipe. For distributed recipes, this supports all torchrun arguments.",
- usage="forge run [TORCHRUN-OPTIONS] --config [RECIPE-OPTIONS]",
- epilog=textwrap.dedent(
- """\
- examples:
-
- # Run SFT recipe with default values
- $ forge run --nproc_per_node 4 apps/sft/sft.py --config apps/sft/configs/llama3_8b.yaml
- """
- ),
- formatter_class=argparse.RawTextHelpFormatter,
- )
- self._add_arguments()
- self._parser.set_defaults(func=self._run_cmd)
-
- def _add_arguments(self) -> None:
- """Add arguments to the parser.
-
- This is a bit hacky since we need to add the torchrun arguments to our parser.
- This grabs the argparser from torchrun, iterates over it's actions, and adds them
- to our parser. We rename the training_script and training_script_args to recipe and recipe_args
- respectively. In addition, we leave out the help argument since we add it manually to ours.
- """
- torchrun_argparser = get_torchrun_args_parser()
- for action in torchrun_argparser._actions:
- if action.dest == "training_script":
- action.dest = "recipe"
- action.help = """Path to recipe to be launched followed by args."""
- elif action.dest == "training_script_args":
- action.dest = "recipe_args"
- action.help = "Args to be passed to the recipe."
- elif action.dest == "help":
- continue
- self._parser._add_action(action)
-
- @record
- def _run_distributed(self, args: argparse.Namespace):
- """Run a recipe with torchrun."""
- print("Running with torchrun...")
- # Have to reset the argv so that the recipe can be run with the correct arguments
- args.training_script = args.recipe
- args.training_script_args = args.recipe_args
-
- # If the user does not explicitly pass a rendezvous endpoint, run in standalone mode.
- # This allows running multiple distributed training jobs simultaneously.
- if not args.rdzv_endpoint:
- args.standalone = True
-
- args.module = True
- run(args)
-
- def _convert_to_dotpath(self, recipe_path: str) -> str:
- """Convert a custom recipe path to a dot path that can be run as a module.
-
- Args:
- recipe_path (str): The path of the recipe.
-
- Returns:
- The dot path of the recipe.
- """
- filepath, _ = os.path.splitext(recipe_path)
- return filepath.replace("/", ".")
-
- def _run_cmd(self, args: argparse.Namespace):
- """Run a recipe."""
- # We have to assume that the recipe supports distributed training
- supports_distributed = True
- recipe_path, config_path = None, None
-
- # Try to find config string in args
- try:
- config_idx = args.recipe_args.index("--config") + 1
- config_str = args.recipe_args[config_idx]
- except ValueError:
- self._parser.error("The '--config' argument is required.")
-
- # Get recipe path
- recipe_path = self._convert_to_dotpath(args.recipe)
-
- # Get config path
- config_path = config_str
-
- # Prepare args
- args.recipe = recipe_path
- args.recipe_args[config_idx] = config_path
-
- # Make sure user code in current directory is importable
- sys.path.append(os.getcwd())
-
- self._run_distributed(args)
diff --git a/src/forge/cli/subcommand.py b/src/forge/cli/subcommand.py
deleted file mode 100644
index db298a0b0..000000000
--- a/src/forge/cli/subcommand.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-class Subcommand:
- def __init__(self, *args, **kwargs):
- pass
-
- @classmethod
- def create(cls, *args, **kwargs):
- return cls(*args, **kwargs)
-
- def _add_arguments(self):
- pass
diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py
index 6d1821d64..a579200e9 100644
--- a/src/forge/controller/__init__.py
+++ b/src/forge/controller/__init__.py
@@ -4,6 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .actor import ForgeActor
-from .proc_mesh import get_proc_mesh, stop_proc_mesh
+from .provisioner import (
+ get_proc_mesh,
+ host_mesh_from_proc,
+ init_provisioner,
+ shutdown,
+ stop_proc_mesh,
+)
-__all__ = ["stop_proc_mesh", "get_proc_mesh", "ForgeActor"]
+__all__ = [
+ "ForgeActor",
+ "get_proc_mesh",
+ "stop_proc_mesh",
+ "init_provisioner",
+ "shutdown",
+ "host_mesh_from_proc",
+]
diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py
index bb495b641..1b8d0a074 100644
--- a/src/forge/controller/actor.py
+++ b/src/forge/controller/actor.py
@@ -8,11 +8,19 @@
import math
import sys
-from typing import Any, Type, TypeVar
+from typing import Any, Type, TYPE_CHECKING, TypeVar
from monarch.actor import Actor, current_rank, current_size, endpoint
-from forge.controller.proc_mesh import get_proc_mesh, stop_proc_mesh
+if TYPE_CHECKING:
+ from monarch._src.actor.actor_mesh import ActorMesh
+
+from forge.controller.provisioner import (
+ get_proc_mesh,
+ register_actor,
+ register_service,
+ stop_proc_mesh,
+)
from forge.types import ProcessConfig, ServiceConfig
@@ -22,11 +30,36 @@
class ForgeActor(Actor):
+ """
+ Base class for Forge actors with configurable resource attributes.
+
+ The initialization sets up logging configuration with rank/size information and
+ initializes the actor's process mesh reference. The rank and size are automatically
+ determined from the current execution context.
+
+ Args:
+ *args: Variable length argument list passed to the parent Actor class.
+ **kwargs: Arbitrary keyword arguments passed to the parent Actor class.
+ """
+
procs: int = 1
+ """Number of processes to use for this actor. Defaults to 1."""
+
hosts: int | None = None
+ """Number of hosts to distribute the actor across. If None, uses as many
+ hosts as needed to accommodate the requested processes. Defaults to None."""
+
with_gpus: bool = False
+ """Whether to allocate GPU resources for this actor. Defaults to False."""
+
num_replicas: int = 1
+ """Number of replicas to create when spawning as a service.
+ Only applies when using as_service(). Defaults to 1."""
+
mesh_name: str | None = None
+ """Optional name for the process mesh used by this actor.
+ If None, a default name will be generated. Defaults to None."""
+
_extra_config: dict[str, Any] = {}
def __init__(self, *args, **kwargs):
@@ -69,23 +102,35 @@ def options(
`.as_actor()` or `.as_service()`. Each call creates a separate subclass, so
multiple different configurations can coexist without interfering with each other.
- ---- Usage Examples ----
+ Examples:
+
+ * Pre-configure a service with multiple replicas:
+
+ .. code-block:: python
+
+ service = await MyForgeActor.options(num_replicas=2, procs=2).as_service(...)
+ await service.shutdown()
+
+ * Default usage without calling options:
- # Pre-configure a service with multiple replicas
- service = await MyForgeActor.options(num_replicas=2, procs=2).as_service(...)
- await service.shutdown()
+ .. code-block:: python
- # Default usage without calling options
- service = await MyForgeActor.as_service(...)
- await service.shutdown()
+ service = await MyForgeActor.as_service(...)
+ await service.shutdown()
- # Pre-configure a single actor
- actor = await MyForgeActor.options(procs=1, hosts=1).as_actor(...)
- await actor.shutdown()
+ * Pre-configure a single actor
- # Default usage without calling options
- actor = await MyForgeActor.as_actor(...)
- await actor.shutdown()
+ .. code-block:: python
+
+ actor = await MyForgeActor.options(procs=1, hosts=1).as_actor(...)
+ await actor.shutdown()
+
+ * Default usage without calling options
+
+ .. code-block:: python
+
+ actor = await MyForgeActor.as_actor(...)
+ await MyForgeActor.shutdown(actor)
"""
attrs = {
@@ -124,10 +169,13 @@ async def as_service(
}
cfg = ServiceConfig(**cfg_kwargs)
- logger.info("Spawning Service for %s", cls.__name__)
+ logger.info(f"Spawning service {cls.__name__}")
service = Service(cfg, cls, actor_args, actor_kwargs)
await service.__initialize__()
- return ServiceInterface(service, cls)
+ service_interface = ServiceInterface(service, cls)
+ # Register this service with the provisioner so it can cleanly shut this down
+ await register_service(service_interface)
+ return service_interface
@endpoint
async def setup(self):
@@ -144,30 +192,8 @@ async def setup(self):
"""
pass
- @endpoint
- async def set_env(self, addr: str, port: str):
- """A temporary workaround to set master addr/port.
-
- TODO - issues/144. This should be done in proc_mesh creation.
- The ideal path:
- - Create a host mesh
- - Grab a host from host mesh, from proc 0 spawn an actor that
- gets addr/port
- - Spawn procs on the HostMesh with addr/port, setting the
- addr/port in bootstrap.
-
- We can't currently do this because HostMesh only supports single
- proc_mesh creation at the moment. This will be possible once
- we have "proper HostMesh support".
-
- """
- import os
-
- os.environ["MASTER_ADDR"] = addr
- os.environ["MASTER_PORT"] = port
-
@classmethod
- async def launch(cls, *args, **kwargs) -> "ForgeActor":
+ async def launch(cls, *args, **kwargs) -> "ActorMesh":
"""Provisions and deploys a new actor.
This method is used by `Service` to provision a new replica.
@@ -193,10 +219,6 @@ async def launch(cls, *args, **kwargs) -> "ForgeActor":
actor_name = kwargs.pop("name", cls.__name__)
actor = proc_mesh.spawn(actor_name, cls, *args, **kwargs)
actor._proc_mesh = proc_mesh
-
- if hasattr(proc_mesh, "_hostname") and hasattr(proc_mesh, "_port"):
- host, port = proc_mesh._hostname, proc_mesh._port
- await actor.set_env.call(addr=host, port=port)
await actor.setup.call()
return actor
@@ -209,8 +231,10 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T:
`procs`) are used to construct a ProcessConfig instance.
If no configuration was stored, defaults to a single process with no GPU.
"""
- logger.info("Spawning single actor %s", cls.__name__)
+ logger.info(f"Spawning actor {cls.__name__}")
actor = await cls.launch(*args, **actor_kwargs)
+ # Register this actor with the provisioner so it can cleanly shut this down
+ await register_actor(actor)
return actor
@classmethod
diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py
index cd54c00b0..2f7e93aec 100644
--- a/src/forge/controller/launcher.py
+++ b/src/forge/controller/launcher.py
@@ -4,26 +4,31 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+"""Launcher specific logic (i.e. SLURM, k8s when supported, etc.)"""
+
+import copy
import getpass
import os
-import socket
import subprocess
+import tempfile
import uuid
from typing import Any
import monarch
-
import torchx.specs as specs
+from forge.types import Launcher, LauncherConfig
from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints
+from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
+from monarch._rust_bindings.monarch_hyperactor.config import configure
from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
from monarch.actor import Actor, endpoint, ProcMesh
from monarch.tools import commands
-from monarch.tools.commands import info
+from monarch.tools.commands import create, info
from monarch.tools.components import hyperactor
from monarch.tools.config import Config, Workspace
-from forge.types import Launcher, LauncherConfig
+_MAST_AVAILABLE = False
try:
from monarch._src.actor.actor_mesh import current_rank
@@ -31,29 +36,69 @@
from monarch.tools.components.meta import hyperactor as meta_hyperactor
from torchx.specs import AppState
from torchx.specs.fb.component_helpers import Packages
+
+ _MAST_AVAILABLE = True
except ImportError as e:
- print(f"Warning: Monarch meta/fb inetrnal imports failed: {e}")
- print("Monarch functionality will be limited")
+ # This means there is an error with MAST
+ pass
JOB_NAME_KEY = "job_name"
LAUNCHER_KEY = "launcher"
-def _get_port() -> str:
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(("localhost", 0))
- addr = s.getsockname()
- port = addr[1]
- return str(port)
+def mount_mnt_directory(mount_dst: str) -> None:
+ """Mounts the MAST remote directory to the specified destination.
+ This function mounts a remote workspace directory that contains huggingface models
+ and other shared resources needed for training.
-class SetupActor(Actor):
- @endpoint
- def get_info(self) -> [str, str]:
- return socket.gethostname(), _get_port()
+ Args:
+ mount_dst: Destination path where the directory should be mounted (e.g., "/mnt/wsfuse")
+ """
+ # Sanity check of the mounted directory
+ sanity_path = os.path.join(mount_dst, "huggingface_models/")
+ if os.path.exists(sanity_path):
+ return
+
+ # Otherwise, mount the directory
+ if not os.path.exists(mount_dst):
+ os.makedirs(mount_dst, exist_ok=True)
+ # Store original LD_LIBRARY_PATH to restore after mounting
+ original_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
-class MastSetupActor(SetupActor):
+ try:
+ clean_env = os.environ.copy()
+ if "LD_LIBRARY_PATH" in clean_env:
+ del clean_env["LD_LIBRARY_PATH"]
+
+ subprocess.run(
+ [
+ "/packages/oil.oilfs/oilfs-wrapper",
+ "ws://ws.ai.pci0ai/genai_fair_llm",
+ mount_dst,
+ ],
+ capture_output=True,
+ text=True,
+ check=True,
+ env=clean_env,
+ )
+ print("Done mounting")
+ except subprocess.CalledProcessError as e:
+ print(f"Get error during mounting {e}, Stderr: {e.stderr}, Stdout: {e.stdout}")
+ finally:
+ # Restore original LD_LIBRARY_PATH
+ if original_ld_library_path:
+ os.environ["LD_LIBRARY_PATH"] = original_ld_library_path
+ elif "LD_LIBRARY_PATH" in os.environ:
+ del os.environ["LD_LIBRARY_PATH"]
+
+ assert os.path.exists(
+ sanity_path
+ ), f"Did not find directory {sanity_path}; something wrong with mounting."
+
+
+class MastSetupActor(Actor):
@endpoint
def mount(self, mount_dst: str):
point = current_rank()
@@ -63,53 +108,7 @@ def mount(self, mount_dst: str):
if current_rank().rank % proc_count != 0:
# Only use one rank per host to mount the directory
return
- self.mount_mnt_directory(mount_dst)
-
- def mount_mnt_directory(self, mount_dst: str) -> None:
- # Sanity check of the mounted directory
- sanity_path = os.path.join(mount_dst, "huggingface_models/")
- if os.path.exists(sanity_path):
- print(f"Found directory {sanity_path}; skip mounting.")
- return
-
- # Otherwise, mount the directory
- if not os.path.exists(mount_dst):
- os.makedirs(mount_dst, exist_ok=True)
-
- # Store original LD_LIBRARY_PATH to restore after mounting
- original_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
-
- try:
- clean_env = os.environ.copy()
- if "LD_LIBRARY_PATH" in clean_env:
- del clean_env["LD_LIBRARY_PATH"]
-
- subprocess.run(
- [
- "/packages/oil.oilfs/oilfs-wrapper",
- "ws://ws.ai.pci0ai/genai_fair_llm",
- mount_dst,
- ],
- capture_output=True,
- text=True,
- check=True,
- env=clean_env,
- )
- print("Done mounting")
- except subprocess.CalledProcessError as e:
- print(
- f"Get error during mounting {e}, Stderr: {e.stderr}, Stdout: {e.stdout}"
- )
- finally:
- # Restore original LD_LIBRARY_PATH
- if original_ld_library_path:
- os.environ["LD_LIBRARY_PATH"] = original_ld_library_path
- elif "LD_LIBRARY_PATH" in os.environ:
- del os.environ["LD_LIBRARY_PATH"]
-
- assert os.path.exists(
- sanity_path
- ), f"Did not find directory {sanity_path}; something wrong with mounting."
+ mount_mnt_directory(mount_dst)
class BaseLauncher:
@@ -119,30 +118,43 @@ async def initialize(self) -> None:
async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]:
pass
- async def remote_setup(self, procs: ProcMesh) -> tuple[str, int]:
+ async def remote_setup(self, procs: ProcMesh) -> None:
pass
class Slurmlauncher(BaseLauncher):
+ def __init__(
+ self,
+ cfg: LauncherConfig,
+ ):
+ self.cfg = cfg
+
async def initialize(self) -> None:
- pass
+ # HostMesh currently requires explicit configuration
+ # of the underlying transport from client to mesh.
+ # This can be removed in the future once this has been removed.
+ configure(default_transport=ChannelTransport.TcpWithHostname)
async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]:
appdef = hyperactor.host_mesh(
image="test", meshes=[f"{name}:{num_hosts}:gpu.small"]
)
for role in appdef.roles:
- # Note - this is hardcoded to SLURM
- # We got this with sinfo
- role.resource.memMB = 2062607
- role.resource.cpu = 128
- role.resource.gpu = 8
+ role.resource.memMB = self.cfg.memMB
+ role.resource.cpu = self.cfg.cpu
+ role.resource.gpu = self.cfg.gpu
- # TODO - multi scheduler support
+ # Note - we cannot add in an empty workspace, so we create a fake temporary one
+ temp_workspace = tempfile.mkdtemp(prefix="forge_workspace_")
server_config = Config(
scheduler="slurm",
+ scheduler_args={
+ "account": self.cfg.account,
+ "qos": self.cfg.qos,
+ "time": "72:00:00",
+ },
appdef=appdef,
- workspace=monarch.tools.config.workspace.Workspace(dirs=[""]),
+ workspace=monarch.tools.config.workspace.Workspace(dirs=[temp_workspace]),
)
server_info = await commands.get_or_create(
"forge_job",
@@ -156,24 +168,54 @@ async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]
server_name = f"slurm:///{server_info.name}"
return alloc, None, server_name # (Allocator, AllocConstraints, SeverName)
- async def remote_setup(self, procs: ProcMesh) -> tuple[str, int]:
- setup = procs.spawn(f"setup-{uuid.uuid1()}", SetupActor)
- return await setup.get_info.choose()
+ async def remote_setup(self, procs: ProcMesh) -> None:
+ return
+
+
+class MastLauncher(BaseLauncher):
+ """Launcher for MAST (Meta's internal cluster scheduler).
+
+ This launcher supports two modes of operation:
+
+ 1. Non-detached mode (detached=False):
+ - Client runs on your local machine/devserver
+ - Only worker roles (GPU hosts) are launched in MAST
+ - Client connects to workers remotely via provisioner
+ 2. Detached mode (detached=True):
+ - Client runs entirely inside MAST as a separate role
+ - Both client role (CPU-only) and worker roles (GPU) are launched in MAST
+ - Client role executes the training script with --mode=remote
+ - Everything runs in the cluster, no client needed on local machine
-class Mastlauncher(BaseLauncher):
- def __init__(self, cfg: LauncherConfig | None = None):
+ Args:
+ cfg: Launcher configuration including job name, services, and actors
+ detached: If True, adds a client role to the MAST job appdef that runs
+ the training script inside MAST. If False, only launches worker
+ roles and expects the client to run on local machine.
+ extra_args: Additional CLI arguments to pass through to the client role.
+
+ """
+
+ def __init__(
+ self,
+ cfg: LauncherConfig | None = None,
+ detached: bool = False,
+ extra_args: list = None,
+ ):
assert cfg is not None
self.cfg = cfg
+ self.detached = detached
self.default_monarch_port = 26600
+ self.extra_args = extra_args or []
self.scheduler_name = "mast_conda"
- # TODO: enabe taking this from config
+ # TODO: enable taking this from config
self.sku = "gtt_any"
self.timeout_sec = 1 * 60 * 60 # Kill the job if idle for 1 hour
self.user = getpass.getuser()
- self.work_dir = f"/data/users/{self.user}"
- self.edittable_workspaces = ["forge"]
+ self.work_dir = f"/home/{self.user}"
+ self.edittable_workspaces = ["torchforge"]
self.remote_work_dir = "/packages/monarch_default_workspace/workspace/"
self.editable_workspace_paths = [
f"{self.work_dir}/{workspace}" for workspace in self.edittable_workspaces
@@ -181,7 +223,10 @@ def __init__(self, cfg: LauncherConfig | None = None):
self.job_name = self.cfg.job_name or self.create_job_name()
async def initialize(self) -> None:
- await self.launch_mast_job()
+ # HostMesh currently requires explicit configuration
+ # of the underlying transport from client to mesh.
+ # This can be removed in the future once this has been removed.
+ configure(default_transport=ChannelTransport.MetaTlsWithHostname)
async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]:
allocator = MastAllocator(
@@ -196,10 +241,9 @@ async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]
return allocator, alloc_constraints, self.create_server_handle()
- async def remote_setup(self, procs: ProcMesh) -> tuple[str, int]:
- setup = procs.spawn(f"setup-{uuid.uuid1()}", MastSetupActor)
+ async def remote_setup(self, procs: ProcMesh) -> None:
+ setup = procs.spawn("mast_setup", MastSetupActor)
await setup.mount.call(mount_dst="/mnt/wsfuse")
- return await setup.get_info.choose()
async def launch_mast_job(self):
handle = self.create_server_handle()
@@ -213,8 +257,8 @@ async def launch_mast_job(self):
scheduler_args={
"hpcIdentity": "hyper_monarch",
"hpcJobOncall": "monarch",
- "hpcClusterUuid": "MastProdCluster",
- "rmAttribution": "pytorch4all_clients_approved",
+ "hpcClusterUuid": "MastGenAICluster",
+ "rmAttribution": "msl_infra_hw_enab_agentrl",
},
appdef=self.build_appdef(),
workspace=Workspace(
@@ -222,31 +266,31 @@ async def launch_mast_job(self):
),
)
- await commands.get_or_create(self.job_name, config)
- return server_spec
+ job_handle = create(config, name=self.job_name)
+ print(
+ f"MAST job launched successfully:\n"
+ f"\033[92mhttps://www.internalfb.com/mlhub/pipelines/runs/mast/{self.job_name}\033[0m"
+ )
+ return job_handle
def add_additional_packages(self, packages: "Packages") -> "Packages":
packages.add_package("oil.oilfs:stable")
- packages.add_package("manifold.manifoldfs")
+ packages.add_package("manifold.manifoldfs:prod")
return packages
def build_appdef(self) -> specs.AppDef:
-
# create the app definition for the worker
- remote_end_python_path = ":".join(
- [
- f"{self.remote_work_dir}{workspace}"
- for workspace in self.editable_workspace_paths
- ]
- )
+ additional_python_paths = [
+ f"{self.remote_work_dir}{workspace}"
+ for workspace in self.editable_workspace_paths
+ ]
+ additional_python_paths.append(self.remote_work_dir)
default_envs = {
**meta_hyperactor.DEFAULT_NVRT_ENVS,
**meta_hyperactor.DEFAULT_NCCL_ENVS,
**meta_hyperactor.DEFAULT_TORCH_ENVS,
- **{
- "TORCHX_RUN_PYTHONPATH": f"{remote_end_python_path}:{self.remote_work_dir}"
- },
+ **{"TORCHX_RUN_PYTHONPATH": ":".join(additional_python_paths)},
**{
"HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS": "600",
"HYPERACTOR_CODE_MAX_FRAME_LENGTH": "1073741824",
@@ -255,11 +299,15 @@ def build_appdef(self) -> specs.AppDef:
"TORCHDYNAMO_VERBOSE": "1",
"VLLM_TORCH_COMPILE_LEVEL": "0",
"VLLM_USE_TRITON_FLASH_ATTN": "0",
+ "HF_HUB_OFFLINE": "1",
+ "TORCHSTORE_RDMA_ENABLED": "1",
+ "HF_HOME": "/mnt/wsfuse/teamforge/hf",
+ "TRANSFORMERS_OFFLINE": "1",
+ "FUSE_SRC": "ws://ws.ai.pci0ai/genai_fair_llm",
+ "FUSE_DST": "/mnt/wsfuse",
},
}
- print("DEFAULT ENVS: ", default_envs)
-
packages = Packages()
meshes = []
# Process both services and actors configurations
@@ -289,6 +337,15 @@ def build_appdef(self) -> specs.AppDef:
timeout_sec=self.timeout_sec,
env=default_envs,
)
+ appdef.metadata["mast"] = {
+ "HpcJobDefinition": {
+ "networkAffinity": {
+ # Ensure colocation
+ "preferredScope": 3, # DC
+ "fallbackScope": 3, # REGION
+ },
+ },
+ }
for role in appdef.roles:
role.resource.capabilities["server_sub_types"] = [
@@ -296,8 +353,45 @@ def build_appdef(self) -> specs.AppDef:
role.resource.capabilities["server_sub_types"][1] # GTT
]
+ # Add client role to run in MAST if in detached mode
+ if self.detached:
+ client_role = self._create_client_role(appdef)
+ appdef.roles.insert(0, client_role)
+
return appdef
+ def _create_client_role(self, appdef: specs.AppDef) -> specs.Role:
+ # Clone an existing worker role to inherit workspace configuration
+ if not appdef.roles:
+ raise ValueError(
+ "Cannot create client role: no worker roles exist to clone from"
+ )
+
+ # Clone the first worker role
+ client_role = copy.deepcopy(appdef.roles[0])
+
+ # Override with client-specific configuration
+ client_role.name = "client"
+ # Use the bootstrap script as entrypoint
+ client_role.entrypoint = "workspace/torchforge/.meta/mast/client_bootstrap.sh"
+
+ # Build args for the client role (passed to the bootstrap script)
+ # These args will be passed to client_bootstrap.sh which forwards them to main.py
+ args = [
+ "--mode=remote",
+ "--job-name",
+ self.job_name,
+ ]
+
+ # Add any extra args passed from the CLI (includes --config and other args)
+ if self.extra_args:
+ args.extend(self.extra_args)
+
+ client_role.args = args
+ client_role.num_replicas = 1
+
+ return client_role
+
def create_job_name(self):
return f"{self.user}-forge-{uuid.uuid4().hex[:6]}"
@@ -306,9 +400,15 @@ def create_server_handle(self) -> str:
def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None:
- if not cfg or cfg.launcher == Launcher.SLURM:
- return Slurmlauncher()
+ if not cfg:
+ return None
+ if cfg.launcher == Launcher.SLURM:
+ return Slurmlauncher(cfg)
elif cfg.launcher == Launcher.MAST:
- return Mastlauncher(cfg)
+ if not _MAST_AVAILABLE:
+ raise ValueError(
+ "MAST imports did not succeed, cannot launch MAST jobs. Please verify your installation"
+ )
+ return MastLauncher(cfg, detached=False)
else:
raise ValueError(f"Unsupported config provided, got {cfg}")
diff --git a/src/forge/controller/proc_mesh.py b/src/forge/controller/proc_mesh.py
deleted file mode 100644
index 099826d86..000000000
--- a/src/forge/controller/proc_mesh.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-"""Spawning utils for actors and proc_meshes."""
-import logging
-
-from monarch.actor import ProcMesh
-
-from forge.controller.provisioner import (
- get_proc_mesh as _get_proc_mesh,
- stop_proc_mesh as _stop_proc_mesh,
-)
-from forge.types import ProcessConfig
-
-logger: logging.Logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-async def get_proc_mesh(process_config: ProcessConfig) -> ProcMesh:
- """Returns a proc mesh with the given process config."""
- # TODO - remove this
- return await _get_proc_mesh(process_config)
-
-
-async def stop_proc_mesh(mesh: ProcMesh) -> None:
- """Stops the given proc mesh."""
- # TODO - remove this
- return await _stop_proc_mesh(mesh)
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index d66504707..a27b7d0bb 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -4,44 +4,169 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-"""Remote resource allocation and provisioning."""
+"""Remote and local resource manager for allocation and provisioning."""
+
import asyncio
-import functools
import logging
import os
import socket
import uuid
-from typing import Optional
-from monarch._src.actor.shape import NDSlice, Shape
-from monarch.actor import HostMesh, ProcMesh, this_host
-from monarch.tools import commands
+import torch
from forge.controller.launcher import BaseLauncher, get_launcher
+from forge.env import all_env_vars, FORGE_DISABLE_METRICS
+from forge.types import ProcessConfig, ProvisionerConfig
-from forge.observability.metric_actors import get_or_create_metric_logger
+from monarch._src.actor.actor_mesh import ActorMesh
+from monarch._src.actor.shape import Extent
-from forge.types import ProcessConfig, ProvisionerConfig
+from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
+
+from monarch.tools import commands
+from monarch.utils import setup_env_for_distributed
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
+def _get_port() -> str:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("localhost", 0))
+ addr = s.getsockname()
+ port = addr[1]
+ return str(port)
+
+
+class _RemoteInfoFetcher(Actor):
+ """An actor responsible for getting remote host information."""
+
+ @endpoint
+ def get_info(self) -> tuple[str, str]:
+ """Returns hostname and port."""
+ return socket.gethostname(), _get_port()
+
+ @endpoint
+ def get_gpu_count(self) -> int:
+ """Returns the number of GPUs available on this host."""
+ try:
+ gpu_count = torch.cuda.device_count()
+ except Exception:
+ # If torch is not available or CUDA is not available, assume no GPUs
+ gpu_count = 0
+ return gpu_count
+
+
+class EnvSetter(Actor):
+ """Actor to set environment variables on each proc in a mesh.
+
+ Ideally, this is handled in spawn_procs's bootstrap call which
+ essentially does the same thing as we're doing here.
+
+ However, Monarch's SetupActor currently fails to stop on shutdown
+ which leads to zombie messages sent to the SetupActor. This is a
+ known issue, and we will move back to bootstrap once it's fixed.
+
+ We are able to avoid this here by properly awaiting the spawning
+ of the actor.
+
+ """
+
+ @endpoint
+ def set_env(self, env_vars: dict[str, str]):
+ """Set environment variables on this proc.
+
+ Args:
+ env_vars: Dictionary of environment variables to set
+ """
+ import os
+ import socket
+
+ # Set VLLM_HOST_IP (required for vLLM on multiple nodes)
+ os.environ["VLLM_HOST_IP"] = socket.gethostbyname(socket.getfqdn())
+
+ # Set user-provided environment variables
+ for k, v in env_vars.items():
+ os.environ[k] = v
+
+
+async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
+ """Returns the host name and port of the host mesh."""
+ throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
+ fetcher = throwaway_procs.spawn("_fetcher", _RemoteInfoFetcher)
+
+ # This will reduce something like extent = {"hosts": 2, "procs": 1} to
+ # {"hosts": 1, "procs": 1}.
+ singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
+ fetcher = fetcher.slice(**singleton_slice)
+ # Fetcher should be a singleton at this point - call_one() will fail otherwise
+ host, port = await fetcher.get_info.call_one()
+
+ # Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
+ # await throwaway_procs.stop()
+ return host, port
+
+
+async def get_host_gpus(host_mesh: HostMesh) -> int:
+ """Returns the number of GPUs available on the host mesh."""
+ throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
+ fetcher = throwaway_procs.spawn("_gpu_counter", _RemoteInfoFetcher)
+
+ # Reduce to a singleton
+ singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
+ fetcher = fetcher.slice(**singleton_slice)
+
+ gpu_count = await fetcher.get_gpu_count.call_one()
+ return gpu_count
+
+
+async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
+ """Set environment variables on a proc mesh using EnvSetter actor.
+
+ This replaces the old bootstrap approach to avoid Monarch's SetupActor
+ mesh failures on shutdown.
+
+ Args:
+ proc_mesh: The proc mesh to set environment variables on
+ env_vars: Dictionary of environment variables to set
+ """
+ env_setter = proc_mesh.spawn("_env_setter", EnvSetter)
+ await env_setter.set_env.call(env_vars)
+
+
class GpuManager:
"""Tracks and assigns GPU devices on a host.
- This currently mimics the `gpu_manager` in system_controllers - we will
- consolidate as part of the "proper HostMesh integration" work.
+ Args:
+ available_devices: Set of GPU device IDs to manage. If None, uses all devices from 0 to max_device_count-1.
+ max_device_count: Maximum number of GPU devices on this host. Defaults to 8.
"""
- def __init__(self, available_devices: set[int] | None = None):
+ def __init__(
+ self, available_devices: set[int] | None = None, max_device_count: int = 8
+ ):
if available_devices is None:
- available_devices = set(range(0, 8))
- assert all(isinstance(x, int) for x in available_devices)
- assert all(x >= 0 and x < 8 for x in available_devices)
+ available_devices = set(range(0, max_device_count))
+ else:
+ # Validate types first
+ assert all(
+ isinstance(x, int) for x in available_devices
+ ), f"All device IDs must be integers, got: {available_devices}"
+ # When available_devices is provided (e.g., from CUDA_VISIBLE_DEVICES),
+ # adjust max_device_count to accommodate the highest device ID
+ if available_devices:
+ max_device_count = max(max(available_devices) + 1, max_device_count)
+
+ assert all(
+ isinstance(x, int) for x in available_devices
+ ), f"All device IDs must be integers, got: {available_devices}"
+ assert all(
+ x >= 0 for x in available_devices
+ ), f"All device IDs must be non-negative, got: {available_devices}"
self.available_gpus = available_devices
+ self.max_device_count = max_device_count
def get_available_gpus(self) -> list[str]:
"""Returns a list of available GPU devices."""
@@ -90,15 +215,30 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
f"Invalid CUDA_VISIBLE_DEVICES format: '{cuda_visible_devices}'. "
f"Expected comma-separated integers (e.g., '0,1,2'). Error: {e}"
) from e
+
+ # Get the actual GPU count for the local host
+ try:
+ local_gpu_count = torch.cuda.device_count()
+ except Exception:
+ # If torch is not available or CUDA is not available, assume no GPUs
+ local_gpu_count = 0
+
self._host_gpu_map = {
- self._this_host_id: GpuManager(available_local_devices),
+ self._this_host_id: GpuManager(
+ available_local_devices, max_device_count=local_gpu_count
+ ),
}
+ self._proc_host_map = {}
+ self._host_mesh_map = {}
self.launcher: BaseLauncher | None = get_launcher(
cfg.launcher_config if cfg is not None else None
)
if not self.launcher:
logger.warning("Launcher not provided, remote allocations will not work.")
+ self._registered_actors: list["ForgeActor"] = []
+ self._registered_services: list["ServiceInterface"] = []
+
async def initialize(self):
"""Call this after creating the instance"""
if self.launcher is not None:
@@ -116,84 +256,139 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
alloc, alloc_constraints, server_name = await self.launcher.get_allocator(
name, num_hosts
)
- return (
- HostMesh(
- Shape(["hosts"], NDSlice.new_row_major([num_hosts])),
- allocator=alloc,
- alloc_constraints=alloc_constraints,
- ),
- server_name,
+
+ # We are asking Monarch to allocate a single process on
+ # every host, reflected in the Extent we provide below.
+
+ # Technically, this is ["hosts", "procs"] but to reduce
+ # confusion on its relationship with procs elsewhere,
+ # we call it "no_dim".
+
+ # TODO - remove this once Monarch supports HostMesh without it.
+ host_mesh = HostMesh.allocate_nonblocking(
+ name=name,
+ extent=Extent(["hosts", "no_dim"], [num_hosts, 1]),
+ allocator=alloc,
+ alloc_constraints=alloc_constraints,
)
+ return host_mesh, server_name
+
+ def get_host_mesh(self, name: str) -> HostMesh:
+ """Returns the host mesh given its associated name.
+
+ This is currently an experimental API for HostMesh v1 and
+ should not be relied on longer term.
+
+ """
+ return self._host_mesh_map[name]
async def get_proc_mesh(
self,
num_procs: int,
with_gpus: bool = False,
num_hosts: int | None = None,
- mesh_name: Optional[str] = None,
+ mesh_name: str | None = None,
+ host_mesh: HostMesh | None = None,
+ env_vars: dict[str, str] | None = None,
+ addr: str | None = None,
+ port: str | None = None,
):
"""Gets a proc mesh.
- num_hosts = None implies that you want a local allocation, this may change.
+ Args:
+ num_procs: The number of processes to allocate.
+ with_gpus: Whether to include GPU allocations.
+ This only adds the CUDA_VISIBLE_DEVICES environment variable.
+ num_hosts: The number of hosts to allocate.
+ If this is set, a remote allocation is created.
+ If this is None, it uses the local host.
+ This behavior may change in the future.
+ host_mesh: The host mesh to allocate the process on.
+ If None, a new host mesh will be created.
+ port: The distributed port to use.
+ If None, a port will be detected.
+ addr: The distributed address to use.
+ If None, an address will be detected.
+
+ Returns:
+ A ProcMesh.
"""
+ if env_vars is None:
+ env_vars = {}
+
+ is_remote = num_hosts is not None and num_hosts > 0
+
async with self._lock:
server_name = None
- if num_hosts is not None and num_hosts > 0:
- created_hosts = len(self._server_names)
- host_mesh, server_name = await self.create_host_mesh(
- name=mesh_name,
- num_hosts=num_hosts,
- )
- host_id = uuid.uuid1()
- gpu_manager = GpuManager()
- self._host_gpu_map[host_id] = gpu_manager
- host_mesh._host_id = host_id
+ if is_remote:
+ if mesh_name is None:
+ created_hosts = len(self._server_names)
+ mesh_name = f"alloc_{created_hosts}"
+ if host_mesh is None:
+ host_mesh, server_name = await self.create_host_mesh(
+ name=mesh_name,
+ num_hosts=num_hosts,
+ )
+ host_id = uuid.uuid1()
+ # Get the GPU count from the remote host
+ remote_gpu_count = await get_host_gpus(host_mesh)
+ gpu_manager = GpuManager(max_device_count=remote_gpu_count)
+ self._host_gpu_map[host_id] = gpu_manager
+ host_mesh._host_id = host_id
+ else:
+ host_id = host_mesh._host_id
+ gpu_manager = self._host_gpu_map[host_id]
else:
+ # fallback to local
host_mesh = this_host()
gpu_manager = self._host_gpu_map[self._this_host_id]
host_mesh._host_id = self._this_host_id
if with_gpus:
- # The ideal path here:
- # - Create a host mesh
- # - Grab a host from host mesh, from proc 0 spawn an actor that
- # gets addr/port
- # - Spawn procs on the HostMesh with addr/port, setting the
- # addr/port in bootstrap.
- # We can't currently do this because HostMesh only supports single
- # proc_mesh creation at the moment. This will be possible once
- # we have "proper HostMesh support".
- def bootstrap(gpu_ids: list[str]):
- # This works for single host, needed for vLLM currently.
- import os
-
- os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
- os.environ["MASTER_ADDR"] = socket.gethostname()
- # Multiple actors trying to call _get_port doesn't work
- # os.environ["MASTER_PORT"] = _get_port()
-
- # Setting the last digit to the first GPU id allows us to i.e.
- # create multiple vLLM instances on the same local host.
- os.environ["MASTER_PORT"] = f"1234{gpu_ids[0]}"
- os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
- os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
-
+ if not addr or not port:
+ addr, port = await get_remote_info(host_mesh)
gpu_ids = gpu_manager.get_gpus(num_procs)
- procs = host_mesh.spawn_procs(
- per_host={"gpus": num_procs},
- bootstrap=functools.partial(bootstrap, gpu_ids=gpu_ids),
+
+ env_vars["MASTER_ADDR"] = addr
+ env_vars["MASTER_PORT"] = port
+
+ # Set the PTD world size
+ world_size = num_procs * (num_hosts or 1)
+ env_vars["WORLD_SIZE"] = str(world_size)
+ env_vars["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
+
+ # Inherit Forge-relevant environment variables from the system
+ for env_var in all_env_vars():
+ env_vars[env_var.name] = str(env_var.get_value())
+
+ # Spawn procs without bootstrap to avoid SetupActor mesh failures
+ procs = host_mesh.spawn_procs(
+ per_host={"procs": num_procs},
+ name=mesh_name,
+ )
+
+ # Set up environment variables (replaces old bootstrap)
+ if env_vars:
+ await set_environment(procs, env_vars)
+
+ # Set up PyTorch distributed environment if using GPUs
+ if with_gpus:
+ await setup_env_for_distributed(
+ procs,
+ master_addr=addr,
+ master_port=int(port),
)
- # Pick a random host/port, we'll feed this in afterwards
- # Once we have true HostMesh support, we can do this on proc 0 of each host
- # then spin up the proc meshes with the environment afterwards.
- hostname, port = await self.launcher.remote_setup(procs)
- procs._hostname = hostname
- procs._port = port
+
+ if is_remote:
+ await self.launcher.remote_setup(procs)
+
+ # Tag the proc mesh with additional metadata for our own cleanup later
+ if with_gpus:
+ # Applies any launcher specific remote setup.
procs._gpu_ids = gpu_ids
- else:
- procs = host_mesh.spawn_procs(per_host={"gpus": num_procs})
+ self._host_mesh_map[mesh_name] = host_mesh
procs._host = host_mesh
# If we created a server, track so we can tear it down later.
@@ -201,18 +396,40 @@ def bootstrap(gpu_ids: list[str]):
self._server_names.append(server_name)
self._proc_server_map[procs] = server_name
- # Spawn local logging actor on each process and register with global logger
- _ = await get_or_create_metric_logger(procs)
+ self._proc_host_map[procs] = host_mesh
+ # Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor.
+ # When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
+ if not FORGE_DISABLE_METRICS.get_value():
+ from forge.observability.metric_actors import get_or_create_metric_logger
+
+ _ = await get_or_create_metric_logger(procs, process_name=mesh_name)
return procs
+ async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
+ if proc_mesh not in self._proc_host_map:
+ raise ValueError(
+ "The proc mesh was not allocated with an associated hostmesh."
+ )
+ return self._proc_host_map[proc_mesh]
+
async def stop_proc_mesh(self, proc_mesh: ProcMesh):
"""Stops a proc mesh."""
+ if proc_mesh not in self._proc_host_map:
+ logger.warning(
+ f"proc mesh {proc_mesh} was requested to be stopped, but was either already stopped or "
+ "was never registered with the provisioner."
+ )
+ return
async with self._lock:
- # Deregister local logger from global logger
- if hasattr(proc_mesh, "_local_fetcher"):
+ # Deregister LocalFetcherActor from GlobalLoggingActor
+ if hasattr(proc_mesh, "_local_fetcher") and hasattr(proc_mesh, "_uid"):
+ from forge.observability.metric_actors import (
+ get_or_create_metric_logger,
+ )
+
global_logger = await get_or_create_metric_logger(proc_mesh)
- await global_logger.deregister_fetcher.call_one(proc_mesh)
+ await global_logger.deregister_fetcher.call_one(proc_mesh._uid)
if hasattr(proc_mesh, "_gpu_ids"):
gpu_manager = self._host_gpu_map[proc_mesh._host._host_id]
@@ -221,9 +438,57 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
if proc_mesh in self._proc_server_map:
server_name = self._proc_server_map[proc_mesh]
commands.kill(server_name)
+ del self._proc_host_map[proc_mesh]
+
+ def register_service(self, service: "ServiceInterface") -> None:
+ """Registers a service allocation for cleanup."""
+ # Import ServiceInterface here instead of at top-level to avoid circular import
+ from forge.controller.service import ServiceInterface
+
+ if not isinstance(service, ServiceInterface):
+ raise TypeError(
+ f"register_service expected ServiceInterface, got {type(service)}"
+ )
+
+ self._registered_services.append(service)
+
+ def register_actor(self, actor: "ForgeActor") -> None:
+ """Registers a single actor allocation for cleanup."""
+
+ if not isinstance(actor, ActorMesh):
+ raise TypeError(f"register_actor expected ActorMesh, got {type(actor)}")
+
+ self._registered_actors.append(actor)
+
+ async def shutdown_all_allocations(self):
+ """Gracefully shut down all tracked actors and services."""
+ logger.info(
+ f"Shutting down {len(self._registered_services)} service(s) and {len(self._registered_actors)} actor(s)..."
+ )
+ # --- ServiceInterface ---
+ for service in reversed(self._registered_services):
+ try:
+ await service.shutdown()
+
+ except Exception as e:
+ logger.warning(f"Failed to shut down {service}: {e}")
+
+ # --- Actor instance (ForgeActor or underlying ActorMesh) ---
+ for actor in reversed(self._registered_actors):
+ try:
+ # Get the class to call shutdown on (ForgeActor or its bound class)
+ actor_cls = getattr(actor, "_class", None) or actor.__class__
+ await actor_cls.shutdown(actor)
+
+ except Exception as e:
+ logger.warning(f"Failed to shut down {actor}: {e}")
+
+ self._registered_actors.clear()
+ self._registered_services.clear()
async def shutdown(self):
"""Tears down all remaining remote allocations."""
+ await self.shutdown_all_allocations()
async with self._lock:
for server_name in self._server_names:
commands.kill(server_name)
@@ -246,22 +511,88 @@ async def _get_provisioner():
return _provisioner
-async def get_proc_mesh(config: ProcessConfig) -> ProcMesh:
+async def get_proc_mesh(
+ process_config: ProcessConfig,
+ host_mesh: HostMesh | None = None,
+ env_vars: dict[str, str] | None = None,
+ port: str | None = None,
+ addr: str | None = None,
+) -> ProcMesh:
+ """Returns a proc mesh from the provisioner.
+
+ Args:
+ process_config: The process config.
+ host_mesh: The host mesh to allocate the process on.
+ If None, a new host mesh will be created.
+ port: The distributed port to use.
+ If None, a port will be detected.
+ addr: The distributed address to use.
+ If None, an address will be detected.
+
+ Returns:
+ A proc mesh.
+
+ """
provisioner = await _get_provisioner()
return await provisioner.get_proc_mesh(
- num_procs=config.procs,
- with_gpus=config.with_gpus,
- num_hosts=config.hosts,
- mesh_name=config.mesh_name,
+ num_procs=process_config.procs,
+ with_gpus=process_config.with_gpus,
+ num_hosts=process_config.hosts,
+ mesh_name=process_config.mesh_name,
+ host_mesh=host_mesh,
+ env_vars=env_vars,
+ port=port,
+ addr=addr,
)
+async def host_mesh_from_proc(proc_mesh: ProcMesh):
+ """Returns the host mesh that allocated the original proc_mesh.
+
+ This functionality will be enabled in Monarch, so this is a temporary
+ API.
+
+ """
+ provisioner = await _get_provisioner()
+ return await provisioner.host_mesh_from_proc(proc_mesh)
+
+
+async def register_service(service: "ServiceInterface") -> None:
+ """Registers a service allocation with the global provisioner."""
+ provisioner = await _get_provisioner()
+ provisioner.register_service(service)
+
+
+async def register_actor(actor: "ForgeActor") -> None:
+ """Registers an actor allocation with the global provisioner."""
+ provisioner = await _get_provisioner()
+ provisioner.register_actor(actor)
+
+
async def stop_proc_mesh(proc_mesh: ProcMesh):
provisioner = await _get_provisioner()
return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh)
+async def shutdown_metric_logger():
+ """Shutdown the global metric logger and all its backends."""
+ from forge.observability.metric_actors import get_or_create_metric_logger
+
+ logger.info("Shutting down metric logger...")
+ try:
+ mlogger = await get_or_create_metric_logger()
+ await mlogger.shutdown.call_one()
+ except Exception as e:
+ logger.warning(f"Failed to shutdown metric logger: {e}")
+
+
async def shutdown():
+ await shutdown_metric_logger()
+
logger.info("Shutting down provisioner..")
+
provisioner = await _get_provisioner()
- return await provisioner.shutdown()
+ result = await provisioner.shutdown()
+
+ logger.info("Shutdown completed successfully")
+ return result
diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py
index 5b7e2f884..c64d5c3f3 100644
--- a/src/forge/controller/service/interface.py
+++ b/src/forge/controller/service/interface.py
@@ -14,7 +14,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Dict, Generic, List, ParamSpec, TypeVar
+from typing import Generic, ParamSpec, TypeVar
from monarch._src.actor.endpoint import EndpointProperty
@@ -96,7 +96,7 @@ async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
sess_id = kwargs.pop("sess_id", None)
return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs)
- async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
"""Broadcasts a request to all healthy replicas and returns the results as a list."""
result = await self.service.call_all(self.endpoint_name, *args, **kwargs)
return result
@@ -107,7 +107,7 @@ async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R:
"Services only support route() and fanout()."
)
- async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def call(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
raise NotImplementedError(
"You tried to use call() on a service, not an actor. "
"Services only support route() and fanout()."
@@ -119,7 +119,7 @@ async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R:
"Services only support route() and fanout()."
)
- async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
raise NotImplementedError(
"You tried to use broadcast() on a service, not an actor. "
"Services only support route() and fanout()."
@@ -157,7 +157,7 @@ async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R:
sess_id, self.endpoint_name, *args, **kwargs
)
- async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def call(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
"""Broadcasts a request to all healthy replicas and returns the results as a list."""
result = await self.actor_mesh.call_all.call_one(
self.endpoint_name, *args, **kwargs
@@ -314,9 +314,9 @@ class Router(ABC):
@abstractmethod
def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
"""Select a replica from the list based on routing logic."""
pass
diff --git a/src/forge/controller/service/metrics.py b/src/forge/controller/service/metrics.py
index d328728bd..728d7d57a 100644
--- a/src/forge/controller/service/metrics.py
+++ b/src/forge/controller/service/metrics.py
@@ -12,7 +12,6 @@
"""
from dataclasses import dataclass, field
-from typing import Dict, List
from forge.controller.service.replica import ReplicaMetrics
@@ -35,7 +34,7 @@ class ServiceMetrics:
"""
# Replica metrics
- replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict)
+ replica_metrics: dict[int, ReplicaMetrics] = field(default_factory=dict)
# Service-level metrics
total_sessions: int = 0
healthy_replicas: int = 0
@@ -50,7 +49,7 @@ def get_total_request_rate(self, window_seconds: float = 60.0) -> float:
for metrics in self.replica_metrics.values()
)
- def get_avg_queue_depth(self, replicas: List) -> float:
+ def get_avg_queue_depth(self, replicas: list) -> float:
"""Get average queue depth across all healthy replicas."""
healthy_replicas = [r for r in replicas if r.healthy]
if not healthy_replicas:
@@ -58,7 +57,7 @@ def get_avg_queue_depth(self, replicas: List) -> float:
total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas)
return total_queue_depth / len(healthy_replicas)
- def get_avg_capacity_utilization(self, replicas: List) -> float:
+ def get_avg_capacity_utilization(self, replicas: list) -> float:
"""Get average capacity utilization across all healthy replicas."""
healthy_replicas = [r for r in replicas if r.healthy]
if not healthy_replicas:
diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py
index dfdb10169..2cfc426bb 100644
--- a/src/forge/controller/service/replica.py
+++ b/src/forge/controller/service/replica.py
@@ -11,13 +11,12 @@
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
-from typing import Optional
-
-from monarch.actor import ActorError
from forge.controller import ForgeActor
from forge.types import ProcessConfig
+from monarch.actor import ActorError
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -81,7 +80,7 @@ class ServiceRequest:
"""
- session_id: Optional[str]
+ session_id: str | None
function: str
args: tuple
kwargs: dict
@@ -107,7 +106,7 @@ class Replica:
actor_kwargs: dict
# The Actor that this replica is running
- actor: Optional[ForgeActor] = None
+ actor: ForgeActor | None = None
# Async queue for incoming requests
request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue)
@@ -127,10 +126,10 @@ class Replica:
return_first_rank_result: bool = False
# Recovery-related state
- _recovery_task: Optional[asyncio.Task] = None
+ _recovery_task: asyncio.Task | None = None
# Run task is the replica's event loop
- _run_task: Optional[asyncio.Task] = None
+ _run_task: asyncio.Task | None = None
# Metrics tracking
metrics: ReplicaMetrics = field(default_factory=ReplicaMetrics)
@@ -159,10 +158,13 @@ async def initialize(self):
# Deploy the actor and its underlying resources
logger.debug(f"Launching actor for replica {self.idx}")
- mesh_name_with_replica = f"{self.proc_config.mesh_name}_{self.idx}"
- self.proc_config.mesh_name = mesh_name_with_replica
- if hasattr(self.actor_def, "mesh_name"):
- self.actor_def.mesh_name = mesh_name_with_replica
+ # If a Mesh name was specified, incorporate this info.
+ if self.proc_config.mesh_name:
+ mesh_name_with_replica = f"{self.proc_config.mesh_name}_{self.idx}"
+ self.proc_config.mesh_name = mesh_name_with_replica
+ if hasattr(self.actor_def, "mesh_name"):
+ self.actor_def.mesh_name = mesh_name_with_replica
+
self.actor = await self.actor_def.launch(
*self.actor_args,
**self.actor_kwargs,
diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py
index 502402e36..a53d9c873 100644
--- a/src/forge/controller/service/router.py
+++ b/src/forge/controller/service/router.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import logging
-from typing import Dict, List
from .interface import Router
from .replica import Replica
@@ -22,9 +21,9 @@ def __init__(self):
def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
if not healthy_replicas:
raise RuntimeError("No healthy replicas available for load balancing")
@@ -40,9 +39,9 @@ class LeastLoadedRouter(Router):
def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
if not healthy_replicas:
raise RuntimeError("No healthy replicas available for session assignment")
@@ -57,9 +56,9 @@ def __init__(self, fallback_router: Router):
def get_replica(
self,
- healthy_replicas: List[Replica],
+ healthy_replicas: list[Replica],
sess_id: str | None = None,
- session_map: Dict[str, int] | None = None,
+ session_map: dict[str, int] | None = None,
) -> Replica:
if sess_id is None:
raise ValueError("SessionRouter requires a session ID")
diff --git a/src/forge/controller/service/service.md b/src/forge/controller/service/service.md
deleted file mode 100644
index 3a8134fa3..000000000
--- a/src/forge/controller/service/service.md
+++ /dev/null
@@ -1,301 +0,0 @@
-# Service - Distributed Actor Service Controller
-
-A robust service orchestration system for managing distributed actor-based workloads with fault tolerance and intelligent load balancing.
-
-## Overview
-
-The Service class provides a unified interface for deploying and managing multiple replicas of actor-based services across distributed compute resources. It automatically handles replica lifecycle, request routing, and session management.
-
-## Key Features
-
-### **Fault Tolerance**
-- **Health Monitoring**: Continuous health checks with automatic replica recovery
-- **Request Migration**: Seamless migration of requests from failed replicas
-- **Session Preservation**: Maintains session state during replica failures
-- **Graceful Degradation**: Continues operation with reduced capacity
-
-### **Load Balancing**
-- **Round-Robin**: Default load distribution across healthy replicas
-- **Least-Loaded**: Session assignment to replicas with lowest load
-- **Session Affinity**: Sticky sessions for stateful workloads
-- **Custom Routing**: Extensible routing logic for specialized use cases
-
-### **Comprehensive Metrics**
-- **Request Metrics**: Throughput, latency, success/failure rates
-- **Capacity Metrics**: Utilization, queue depth, active requests
-- **Service Metrics**: Session counts, replica health, scaling events
-- **Real-time Monitoring**: Sliding window metrics for responsive scaling
-
-### **Session Management**
-- **Context-Aware Sessions**: Automatic session context propagation
-- **Session Lifecycle**: Managed session creation and cleanup
-- **Routing Hints**: Custom session routing based on workload characteristics
-
-## Architecture
-
-```
-┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
-│ Client API │───▶│ Service Layer │───▶│ Replica Pool │
-└─────────────────┘ └──────────────────┘ └─────────────────┘
- │ │
- ▼ ▼
- ┌──────────────┐ ┌─────────────┐
- │ Autoscaler │ │ Actor Mesh │
- └──────────────┘ └─────────────┘
- │ │
- ▼ ▼
- ┌──────────────┐ ┌─────────────┐
- │ Metrics │ │ Health │
- │ Collector │ │ Monitor │
- └──────────────┘ └─────────────┘
-```
-
-## Usage
-
-### Basic Service Setup
-
-```python
-from forge.controller.service import Service, ServiceConfig
-
-# Configure service parameters
-config = ServiceConfig(
- gpus_per_replica=1,
- min_replicas=2,
- max_replicas=10,
- default_replicas=3,
- replica_max_concurrent_requests=10,
-)
-
-# Create service with your actor definition
-service = Service(config, MyActorClass, *actor_args, **actor_kwargs)
-await service.__initialize__()
-```
-
-### Session-Based Calls
-
-```python
-# Context manager for session lifecycle
-async with service.session() as session:
- result1 = await service.my_endpoint(arg1, arg2)
- result2 = await service.another_endpoint(arg3)
- # Session automatically terminated on exit
-
-# Manual session management
-session_id = await service.start_session()
-result = await service.my_endpoint(session_id, arg1, arg2)
-await service.terminate_session(session_id)
-```
-
-### Stateless Calls
-
-```python
-# Direct calls without sessions (uses round-robin load balancing)
-result = await service.my_endpoint(arg1, arg2)
-```
-
-### Custom Routing
-
-```python
-# Override _custom_replica_routing for specialized routing logic
-class CustomService(Service):
- async def _custom_replica_routing(self, sess_id: str | None, **kwargs) -> Optional[Replica]:
- # Custom routing based on request characteristics
- if kwargs.get('priority') == 'high':
- return self._get_least_loaded_replica()
- return None # Fall back to default routing
-
-# Use with routing hints
-async with service.session(priority='high') as session:
- result = await service.my_endpoint(arg1, arg2)
-```
-
-### Monitoring and Metrics
-
-```python
-# Get detailed metrics
-metrics = service.get_metrics()
-print(f"Total request rate: {metrics.get_total_request_rate()}")
-print(f"Average queue depth: {metrics.get_avg_queue_depth()}")
-print(f"Capacity utilization: {metrics.get_avg_capacity_utilization(service._replicas)}")
-
-# Get summary for monitoring dashboards
-summary = service.get_metrics_summary()
-print(f"Healthy replicas: {summary['service']['healthy_replicas']}")
-print(f"Total sessions: {summary['service']['total_sessions']}")
-
-# Per-replica metrics
-for replica_idx, replica_metrics in summary['replicas'].items():
- print(f"Replica {replica_idx}: {replica_metrics['request_rate']:.1f} req/s")
-```
-
-### Graceful Shutdown
-
-```python
-# Stop the service and all replicas
-await service.stop()
-```
-
-## Configuration
-
-### ServiceConfig
-
-| Parameter | Type | Description |
-|-----------|------|-------------|
-| `gpus_per_replica` | int | Number of GPUs allocated per replica |
-| `min_replicas` | int | Minimum number of replicas to maintain |
-| `max_replicas` | int | Maximum number of replicas allowed |
-| `default_replicas` | int | Initial number of replicas to start |
-| `replica_max_concurrent_requests` | int | Maximum concurrent requests per replica |
-| `health_poll_rate` | float | Health check frequency in seconds |
-| `return_first_rank_result` | bool | Auto-unwrap ValueMesh to first rank's result |
-| `autoscaling` | AutoscalingConfig | Autoscaling configuration |
-
-### AutoscalingConfig
-
-#### Scale Up Triggers
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| `scale_up_queue_depth_threshold` | 5.0 | Average queue depth to trigger scale up |
-| `scale_up_capacity_threshold` | 0.8 | Capacity utilization to trigger scale up |
-| `scale_up_request_rate_threshold` | 10.0 | Requests/sec to trigger scale up |
-
-#### Scale Down Triggers
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| `scale_down_capacity_threshold` | 0.3 | Capacity utilization to trigger scale down |
-| `scale_down_queue_depth_threshold` | 1.0 | Average queue depth to trigger scale down |
-| `scale_down_idle_time_threshold` | 300.0 | Seconds of low utilization before scale down |
-
-#### Timing Controls
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| `min_time_between_scale_events` | 60.0 | Minimum seconds between scaling events |
-| `scale_up_cooldown` | 30.0 | Cooldown after scale up |
-| `scale_down_cooldown` | 120.0 | Cooldown after scale down |
-
-#### Scaling Behavior
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| `scale_up_step_size` | 1 | How many replicas to add at once |
-| `scale_down_step_size` | 1 | How many replicas to remove at once |
-
-#### Safety Limits
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| `max_queue_depth_emergency` | 20.0 | Emergency scale up threshold |
-| `min_healthy_replicas_ratio` | 0.5 | Minimum ratio of healthy replicas |
-
-## Metrics
-
-### Service-Level Metrics
-- **Total Sessions**: Number of active sessions
-- **Healthy Replicas**: Number of operational replicas
-- **Total Request Rate**: Requests per second across all replicas
-- **Average Queue Depth**: Average pending requests per replica
-- **Average Capacity Utilization**: Average resource usage across replicas
-- **Sessions Per Replica**: Distribution of sessions across replicas
-
-### Replica-Level Metrics
-- **Request Counts**: Total, successful, and failed requests
-- **Request Rate**: Requests per second (sliding window)
-- **Average Latency**: Response time (sliding window)
-- **Active Requests**: Currently processing requests
-- **Queue Depth**: Pending requests in queue
-- **Assigned Sessions**: Number of sessions assigned to replica
-- **Capacity Utilization**: Current load vs maximum capacity
-
-## Use Cases
-
-### ML Model Serving
-```python
-# High-throughput model inference with automatic scaling
-config = ServiceConfig(
- gpus_per_replica=1,
- min_replicas=2,
- max_replicas=20,
- default_replicas=4,
- replica_max_concurrent_requests=8,
- autoscaling=AutoscalingConfig(
- scale_up_capacity_threshold=0.7,
- scale_up_queue_depth_threshold=3.0
- )
-)
-service = Service(config, ModelInferenceActor, model_path="/path/to/model")
-```
-
-### Batch Processing
-```python
-# Parallel job execution with fault tolerance
-config = ServiceConfig(
- gpus_per_replica=2,
- min_replicas=1,
- max_replicas=10,
- default_replicas=3,
- replica_max_concurrent_requests=5,
- autoscaling=AutoscalingConfig(
- scale_up_queue_depth_threshold=10.0,
- scale_down_idle_time_threshold=600.0
- )
-)
-service = Service(config, BatchProcessorActor, batch_size=100)
-```
-
-### Real-time Analytics
-```python
-# Stream processing with session affinity
-config = ServiceConfig(
- gpus_per_replica=1,
- min_replicas=3,
- max_replicas=15,
- default_replicas=5,
- replica_max_concurrent_requests=20,
- autoscaling=AutoscalingConfig(
- scale_up_request_rate_threshold=50.0,
- scale_up_capacity_threshold=0.6
- )
-)
-service = Service(config, StreamProcessorActor, window_size=1000)
-```
-
-## Performance Characteristics
-
-- **Low Latency**: Sub-millisecond request routing overhead
-- **High Throughput**: Concurrent request processing across replicas
-- **Elastic Scaling**: Responsive to traffic patterns with configurable thresholds
-- **Resource Efficient**: Intelligent replica management and load balancing
-- **Fault Resilient**: Automatic recovery from replica failures
-- **Session Aware**: Maintains state consistency for stateful workloads
-
-## Best Practices
-
-### Configuration
-- Set `min_replicas` based on baseline load requirements
-- Configure `max_replicas` based on resource constraints
-- Tune autoscaling thresholds based on workload characteristics
-- Use appropriate cooldown periods to prevent scaling oscillation
-
-### Session Management
-- Use sessions for stateful workloads requiring consistency
-- Prefer stateless calls for better load distribution
-- Implement custom routing for specialized workload requirements
-
-### Monitoring
-- Monitor key metrics: request rate, queue depth, capacity utilization
-- Set up alerts for unhealthy replicas and scaling events
-- Track session distribution for load balancing effectiveness
-
-### Error Handling
-- Implement proper error handling in actor endpoints
-- Use try-catch blocks around service calls
-- Monitor failed request rates for service health
-
-## Dependencies
-
-- `monarch.actor`: Actor framework for distributed computing
-- `recoverable_mesh`: Fault-tolerant process mesh management
-- `asyncio`: Asynchronous I/O support
-- `contextvars`: Context variable support for session management
-
-## Thread Safety
-
-The Service class is designed for use in asyncio environments and is not thread-safe. All operations should be performed within the same event loop.
diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py
index 0b655fb6a..9e843610f 100644
--- a/src/forge/controller/service/service.py
+++ b/src/forge/controller/service/service.py
@@ -36,9 +36,6 @@
import logging
import pprint
import uuid
-from typing import Dict, List
-
-from monarch.actor import Actor, endpoint
from forge.controller.service.interface import _session_context, Session
@@ -52,6 +49,8 @@
)
from forge.types import ServiceConfig
+from monarch.actor import Actor, endpoint
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -68,13 +67,6 @@ class Service:
actor_def: Actor class definition to instantiate on each replica
*actor_args: Positional arguments passed to actor constructor
**actor_kwargs: Keyword arguments passed to actor constructor
-
-
- Attributes:
- _cfg: Service configuration
- _replicas: List of managed replica instances
- _active_sessions: Currently active sessions
- _metrics: Aggregated service and replica metrics
"""
def __init__(
@@ -92,7 +84,7 @@ def __init__(
self._active_sessions = []
self._id_session_map = {}
- self._session_replica_map: Dict[str, int] = {}
+ self._session_replica_map: dict[str, int] = {}
# Initialize metrics collection
self._metrics = ServiceMetrics()
@@ -196,7 +188,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
)
raise
- async def call_all(self, function: str, *args, **kwargs) -> List:
+ async def call_all(self, function: str, *args, **kwargs) -> list:
"""
Broadcasts a function call to all healthy replicas and returns results as a list.
@@ -293,8 +285,8 @@ async def _migrate_remaining_requests(self, failed_replica: Replica):
return
# Distribute requests among healthy replicas
- for i, request in enumerate(migrated_requests):
- target_replica = healthy_replicas[i % len(healthy_replicas)]
+ for request in migrated_requests:
+ target_replica = self._default_router.get_replica(healthy_replicas)
await target_replica.enqueue_request(request)
# Update session mapping if needed
@@ -486,6 +478,10 @@ async def _get_replica(self, sess_id: str | None) -> "Replica":
)
async def stop(self):
+ """
+ Stops the service and all managed replicas.
+ This method should be called when the service is no longer needed.
+ """
logger.debug("Stopping service...")
# Signal shutdown to health loop
self._shutdown_requested = True
@@ -605,12 +601,6 @@ class ServiceActor(Actor):
actor_def: Actor class definition to instantiate on each replica
*actor_args: Positional arguments passed to actor constructor
**actor_kwargs: Keyword arguments passed to actor constructor
-
- Attributes:
- _cfg: Service configuration
- _replicas: List of managed replica instances
- _active_sessions: Currently active sessions
- _metrics: Aggregated service and replica metrics
"""
def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
@@ -622,7 +612,7 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
self._active_sessions = []
self._id_session_map = {}
- self._session_replica_map: Dict[str, int] = {}
+ self._session_replica_map: dict[str, int] = {}
self._next_replica_idx = 0 # For round-robin load balancing
# Initialize metrics collection
@@ -726,7 +716,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
raise
@endpoint
- async def call_all(self, function: str, *args, **kwargs) -> List:
+ async def call_all(self, function: str, *args, **kwargs) -> list:
"""
Broadcasts a function call to all healthy replicas and returns results as a list.
diff --git a/src/forge/controller/service/spawn.py b/src/forge/controller/service/spawn.py
index c7c5cf29f..bf3614956 100644
--- a/src/forge/controller/service/spawn.py
+++ b/src/forge/controller/service/spawn.py
@@ -8,13 +8,13 @@
import logging
from typing import Type
-from monarch.actor import proc_mesh
-
from forge.controller import ForgeActor
from forge.controller.service import ServiceActor, ServiceConfig
from forge.controller.service.interface import ServiceInterfaceV2
+from monarch.actor import proc_mesh
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
diff --git a/src/forge/controller/system_controllers/__init__.py b/src/forge/controller/system_controllers/__init__.py
deleted file mode 100644
index dd2c4abca..000000000
--- a/src/forge/controller/system_controllers/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from .gpu_manager import get_gpu_ids, release_gpus
-
-__all__ = [
- "get_gpu_ids",
- "release_gpus",
-]
diff --git a/src/forge/controller/system_controllers/gpu_manager.py b/src/forge/controller/system_controllers/gpu_manager.py
deleted file mode 100644
index cbb1bccda..000000000
--- a/src/forge/controller/system_controllers/gpu_manager.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""Implements an actor responsible for tracking and assigning GPU devices on HostMesh."""
-
-import logging
-
-from monarch.actor import Actor, ActorError, endpoint, get_or_spawn_controller
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
-class GpuManager(Actor):
- """An actor that tracks and assigns GPU devices on given HostMeshes."""
-
- def __init__(self):
- # TODO - extend this to support multiple HostMeshes too
- self.available_gpus = set(range(0, 8))
-
- @endpoint
- def get_available_gpus(self) -> list[str]:
- """Returns a list of available GPU devices."""
- return [str(gpu) for gpu in self.available_gpus]
-
- @endpoint
- def get_gpus(self, num_gpus: int) -> list[str]:
- """Assigns GPU devices."""
- if num_gpus > len(self.available_gpus):
- raise RuntimeError("Not enough GPUs available")
- gpus = list(self.available_gpus)[:num_gpus]
- self.available_gpus -= set(gpus)
- return [str(gpu) for gpu in gpus]
-
- @endpoint
- def release_gpus(self, gpu_ids: list[str]) -> None:
- """Releases the given GPU devices."""
- for gpu_id in gpu_ids:
- self.available_gpus.add(int(gpu_id))
-
- def __repr__(self) -> str:
- return "GpuManager"
-
-
-async def get_gpu_manager() -> GpuManager:
- """Gets the singleton GPU manager actor."""
- try:
- return await get_or_spawn_controller("gpu_manager", GpuManager)
- except ActorError as e:
- raise e.exception from e
-
-
-async def get_gpu_ids(num_gpus: int) -> list[str]:
- """Gets GPU IDs for the given number of GPUs."""
- try:
- gpu_manager = await get_or_spawn_controller("gpu_manager", GpuManager)
- return await gpu_manager.get_gpus.call_one(num_gpus)
- except ActorError as e:
- # Raise the underlying error instead of the Monarch error
- raise e.exception from e
-
-
-async def release_gpus(gpu_ids: list[str]) -> None:
- """Releases the given GPU IDs."""
- try:
- gpu_manager = await get_or_spawn_controller("gpu_manager", GpuManager)
- await gpu_manager.release_gpus.call_one(gpu_ids)
- except ActorError as e:
- # Raise the underlying error instead of the Monarch error
- raise e.exception from e
diff --git a/src/forge/data/__init__.py b/src/forge/data/__init__.py
index 4347199b9..6564817a9 100644
--- a/src/forge/data/__init__.py
+++ b/src/forge/data/__init__.py
@@ -3,7 +3,14 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from .collate import collate_packed
+from .collate import collate_packed, collate_padded
+from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
from .utils import CROSS_ENTROPY_IGNORE_IDX
-__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"]
+__all__ = [
+ "collate_packed",
+ "collate_padded",
+ "CROSS_ENTROPY_IGNORE_IDX",
+ "MetricTransform",
+ "DefaultDatasetMetricTransform",
+]
diff --git a/src/forge/data/collate.py b/src/forge/data/collate.py
index ddd0a7519..15e086e51 100644
--- a/src/forge/data/collate.py
+++ b/src/forge/data/collate.py
@@ -7,6 +7,72 @@
from typing import Any, Callable
import torch
+import torch.nn.functional as F
+
+from forge.data.utils import CROSS_ENTROPY_IGNORE_IDX
+
+
+def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
+ """
+ Collate function that pads sequences to the longest sample in the batch.
+
+ Handles any tensor keys by padding to the longest
+ sequence for that key. Uses 0 as default padding value, and
+ CROSS_ENTROPY_IGNORE_IDX (-100) for 'labels' keys.
+
+ Non-tensor fields are collected into lists. The 'metrics' field is
+ special-cased to be flattened (extended) rather than nested.
+
+ Args:
+ batch: List of samples, each containing tensor and non-tensor fields
+
+ Returns:
+ Batched dict with padded tensors and collected non-tensor fields
+
+ Raises:
+ ValueError: If all samples do not have the same keys
+ """
+ if not batch:
+ return {}
+
+ # Verify all samples have the same keys
+ first_sample_keys = batch[0].keys()
+ for sample in batch:
+ if sample.keys() != first_sample_keys:
+ raise ValueError(
+ f"All samples must have the same keys. Expected {first_sample_keys}, got {sample.keys()}"
+ )
+
+ collated = {}
+
+ for key in first_sample_keys:
+ if isinstance(batch[0][key], torch.Tensor):
+ # Find max length for this tensor key
+ max_len = max(sample[key].size(0) for sample in batch)
+
+ # Determine padding value
+ pad_value = CROSS_ENTROPY_IGNORE_IDX if key == "labels" else 0
+
+ # Pad each sample to max_len
+ padded_tensors = []
+ for sample in batch:
+ seq_len = sample[key].size(0)
+ pad_len = max_len - seq_len
+ padded = F.pad(sample[key], (0, pad_len), value=pad_value)
+ padded_tensors.append(padded)
+
+ # Stack into batch
+ collated[key] = torch.stack(padded_tensors)
+ elif key == "metrics":
+ # Flatten metrics lists
+ collated[key] = []
+ for sample in batch:
+ collated[key].extend(sample[key])
+ else:
+ # Collect other non-tensor fields as lists
+ collated[key] = [sample[key] for sample in batch]
+
+ return collated
def collate_packed(
diff --git a/src/forge/data/dataset_metrics/__init__.py b/src/forge/data/dataset_metrics/__init__.py
deleted file mode 100644
index 3a218e282..000000000
--- a/src/forge/data/dataset_metrics/__init__.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from .metric_agg_handlers import (
- AggregationHandler,
- CategoricalCountAggHandler,
- MaxAggHandler,
- MeanAggHandler,
- MetricState,
- MinAggHandler,
- StatsAggHandler,
- SumAggHandler,
-)
-from .metric_aggregator import MetricsAggregator
-from .metric_transform import (
- AggregationType,
- DefaultTrainingMetricTransform,
- Metric,
- MetricTransform,
-)
-
-__all__ = [
- "AggregationType",
- "AggregationHandler",
- "CategoricalCountAggHandler",
- "DefaultTrainingMetricTransform",
- "StatsAggHandler",
- "MaxAggHandler",
- "MeanAggHandler",
- "Metric",
- "MetricState",
- "MetricsAggregator",
- "MetricTransform",
- "MinAggHandler",
- "SumAggHandler",
-]
diff --git a/src/forge/data/dataset_metrics/metric_agg_handlers.py b/src/forge/data/dataset_metrics/metric_agg_handlers.py
deleted file mode 100644
index bb3978a6b..000000000
--- a/src/forge/data/dataset_metrics/metric_agg_handlers.py
+++ /dev/null
@@ -1,466 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-from abc import ABC, abstractmethod
-from collections import Counter, deque
-from dataclasses import dataclass, field
-from typing import Any
-
-import torch
-
-from .metric_transform import AggregationType, Metric
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class MetricState:
- """Mutable state object representing the state of a (source, metric_name) on a single rank.
-
- Attributes:
- source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation.
- metric_name (str): Name of the metric.
- value (float): Current aggregated value, whose meaning depends on the aggregation type
- (e.g., running sum, current max).
- agg_type (AggregationType): Aggregation type.
- metadata (dict[str, Any]): Additional state like count, list of values, etc.
- """
-
- source: str
- metric_name: str
- value: float
- agg_type: AggregationType
- metadata: dict[str, Any] = field(default_factory=dict)
-
-
-class AggregationHandler(ABC):
- """Base class for handling metric aggregation.
-
- Each handler implements a specific aggregation strategy (SUM, MEAN, STATS, etc.)
- and manages the complete lifecycle: initialization, updates, local finalization,
- and distributed reduction. Handlers also handle serialization for checkpointing.
-
- The handler architecture allows pluggable aggregation strategies while maintaining
- consistent interfaces for the MetricsAggregator.
- """
-
- @abstractmethod
- def initialize_metric_state(
- self, source: str, metric_name: str, agg_type: AggregationType
- ) -> MetricState:
- """Create a new MetricState for a (source, metric_name) pair.
-
- Args:
- source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation.
- metric_name (str): Name of the metric.
- agg_type (AggregationType): Aggregation type.
-
- Returns:
- MetricState: New MetricState for this (source, metric_name) pair.
- """
- pass
-
- @abstractmethod
- def update(self, local_agg_metric: MetricState, metric: Metric) -> None:
- """Update cumulative MetricState with new metric info.
-
- Args:
- local_agg_metric (MetricState): State of the aggregation for this metric in the local rank.
- metric (Metric): Input metric info.
- """
- pass
-
- @abstractmethod
- def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]:
- """
- Computes the final value from the locally aggregated state. For example, for mean
- it would mean to divide the tracked sum by the tracked count.
-
- This method may expand a single metric into multiple, for instance,
- a list of numbers into mean, min, max, and percentiles.
-
- Args:
- local_agg_metric (MetricState): The locally aggregated metric state to finalize.
-
- Returns:
- list[MetricState]: List of finalized metric states.
- """
- pass
-
- @abstractmethod
- def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState:
- """
- Merge MetricStates from all ranks into final result. For example, for 'sum', it would mean to
- sum the values from all ranks.
-
- Args:
- local_agg_metrics (list[MetricState]): list of MetricStates from all ranks for a specific
- (source, metric_name) tuple after computing finalize_local_agg.
-
- Returns:
- MetricState: Final result for this (source, metric_name) pair.
- """
- pass
-
- def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """Convert handler-specific metadata to serializable format. Override this when using
- non-serializable types like deque or Counter. For example, convert deque to list, Counter to dict.
-
- Args:
- metadata (dict[str, Any]): AggHandler-specific metadata.
-
- Returns:
- dict[str, Any]: Serializable metadata.
- """
- return metadata.copy()
-
- def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """Restore handler-specific metadata from serialized format. Override this to reverse the
- serialize_metadata transformation. For example, convert list back to deque, dict back to Counter.
-
- Args:
- metadata (dict[str, Any]): AggHandler-specific metadata.
-
- Returns:
- dict[str, Any]: Deserialized metadata.
- """
- return metadata.copy()
-
-
-class SumAggHandler(AggregationHandler):
- """AggHandler for SUM aggregation. Initializes with 0.0 and accumulates metric values."""
-
- def initialize_metric_state(
- self, source: str, metric_name: str, agg_type: AggregationType
- ) -> MetricState:
- return MetricState(
- source=source,
- metric_name=metric_name,
- value=0.0,
- agg_type=agg_type,
- )
-
- def update(self, local_agg_metric: MetricState, metric: Metric) -> None:
- if not isinstance(metric.value, (int, float)):
- raise ValueError(
- f"SumAggHandler expects numeric values, got {type(metric.value)}"
- )
- local_agg_metric.value += metric.value
-
- def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]:
- return [local_agg_metric]
-
- def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState:
- if not local_agg_metrics:
- raise ValueError("Cannot aggregate empty list of metrics")
-
- total = sum(metric.value for metric in local_agg_metrics)
- return MetricState(
- source=local_agg_metrics[0].source,
- metric_name=local_agg_metrics[0].metric_name,
- value=total,
- agg_type=local_agg_metrics[0].agg_type,
- metadata=local_agg_metrics[0].metadata.copy(),
- )
-
-
-class MaxAggHandler(AggregationHandler):
- """AggHandler for MAX aggregation. Tracks maximum value across all updates."""
-
- def initialize_metric_state(
- self, source: str, metric_name: str, agg_type: AggregationType
- ) -> MetricState:
- return MetricState(
- source=source,
- metric_name=metric_name,
- value=float("-inf"),
- agg_type=agg_type,
- )
-
- def update(self, local_agg_metric: MetricState, metric: Metric) -> None:
- if not isinstance(metric.value, (int, float)):
- raise ValueError(
- f"MaxAggHandler expects numeric values, got {type(metric.value)}"
- )
- local_agg_metric.value = max(local_agg_metric.value, metric.value)
-
- def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]:
- return [local_agg_metric]
-
- def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState:
- max_value = max(r.value for r in local_agg_metrics)
- return MetricState(
- source=local_agg_metrics[0].source,
- metric_name=local_agg_metrics[0].metric_name,
- value=max_value,
- agg_type=local_agg_metrics[0].agg_type,
- metadata=local_agg_metrics[0].metadata.copy(),
- )
-
-
-class MinAggHandler(AggregationHandler):
- """AggHandler for MIN aggregation. Tracks minimum value across all updates."""
-
- def initialize_metric_state(
- self, source: str, metric_name: str, agg_type: AggregationType
- ) -> MetricState:
- return MetricState(
- source=source,
- metric_name=metric_name,
- value=float("inf"),
- agg_type=agg_type,
- )
-
- def update(self, local_agg_metric: MetricState, metric: Metric) -> None:
- if not isinstance(metric.value, (int, float)):
- raise ValueError(
- f"MinAggHandler expects numeric values, got {type(metric.value)}"
- )
- local_agg_metric.value = min(local_agg_metric.value, metric.value)
-
- def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]:
- return [local_agg_metric]
-
- def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState:
- min_value = min(r.value for r in local_agg_metrics)
- return MetricState(
- source=local_agg_metrics[0].source,
- metric_name=local_agg_metrics[0].metric_name,
- value=min_value,
- agg_type=local_agg_metrics[0].agg_type,
- metadata=local_agg_metrics[0].metadata.copy(),
- )
-
-
-class MeanAggHandler(AggregationHandler):
- """AggHandler for MEAN aggregation. Maintains running sum and count to compute average."""
-
- def initialize_metric_state(
- self, source: str, metric_name: str, agg_type: AggregationType
- ) -> MetricState:
- return MetricState(
- source=source,
- metric_name=metric_name,
- value=0.0,
- agg_type=agg_type,
- metadata={"sum": 0.0, "count": 0},
- )
-
- def update(self, local_agg_metric: MetricState, metric: Metric) -> None:
- local_agg_metric.metadata["sum"] += metric.value
- local_agg_metric.metadata["count"] += 1
-
- def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]:
- count = local_agg_metric.metadata["count"]
- local_agg_metric.value = (
- local_agg_metric.metadata["sum"] / count if count > 0 else 0.0
- )
- return [local_agg_metric]
-
- def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState:
- total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics)
- total_count = sum(metric.metadata["count"] for metric in local_agg_metrics)
-
- return MetricState(
- source=local_agg_metrics[0].source,
- metric_name=local_agg_metrics[0].metric_name,
- value=total_sum / total_count if total_count > 0 else 0.0,
- agg_type=local_agg_metrics[0].agg_type,
- metadata={"sum": total_sum, "count": total_count},
- )
-
-
-class StatsAggHandler(AggregationHandler):
- """AggHandler for STATS aggregation. Maintains a sliding window of values
- and expands into multiple statistical metrics (mean, min, max, percentiles, std).
-
- Note: Percentiles and standard deviation are approximated in distributed settings by averaging local
- percentiles and standard deviations across ranks. This is mathematically imprecise but provides a
- reasonable approximation for monitoring purposes.
-
- Args:
- window_size (int): Maximum number of recent values to retain for statistics.
-
- Raises:
- ValueError: If window_size is not positive.
- """
-
- def __init__(self, window_size: int = 1000):
- if window_size <= 0:
- raise ValueError(f"window_size must be positive, got {window_size}")
- self.window_size = window_size
-
- def initialize_metric_state(
- self, source: str, metric_name: str, agg_type: AggregationType
- ) -> MetricState:
- return MetricState(
- source=source,
- metric_name=metric_name,
- value=0.0,
- agg_type=agg_type,
- metadata={"values": deque(maxlen=self.window_size)},
- )
-
- def update(self, local_agg_metric: MetricState, metric: Metric) -> None:
- local_agg_metric.metadata["values"].append(metric.value)
-
- def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]:
- values = list(local_agg_metric.metadata["values"])
- if not values:
- return []
-
- values_tensor = torch.tensor(values, dtype=torch.float64)
- n = len(values_tensor)
-
- # Compute stats from the tensor
- sum_val = torch.sum(values_tensor).item()
- mean_val = sum_val / n
- min_val = torch.min(values_tensor).item()
- max_val = torch.max(values_tensor).item()
-
- # Compute percentiles
- percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64)
- p05_val, p50_val, p95_val = torch.quantile(
- values_tensor, percentile_definitions
- ).tolist()
-
- # Return multiple MetricStates with proper agg_types for distributed reduction
- # NOTE: Percentiles use MEAN aggregation which approximates global percentiles
- # by averaging local percentiles.
- metrics = [
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_stat_mean",
- value=mean_val,
- agg_type=AggregationType.MEAN,
- metadata={"sum": sum_val, "count": n},
- ),
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_stat_min",
- value=min_val,
- agg_type=AggregationType.MIN,
- metadata={},
- ),
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_stat_max",
- value=max_val,
- agg_type=AggregationType.MAX,
- metadata={},
- ),
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_stat_p05",
- value=p05_val,
- agg_type=AggregationType.MEAN,
- metadata={"sum": p05_val, "count": 1},
- ),
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_stat_p50",
- value=p50_val,
- agg_type=AggregationType.MEAN,
- metadata={"sum": p50_val, "count": 1},
- ),
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_stat_p95",
- value=p95_val,
- agg_type=AggregationType.MEAN,
- metadata={"sum": p95_val, "count": 1},
- ),
- ]
-
- # Standard deviation is only well-defined for n > 1
- if n > 1:
- std_val = torch.std(values_tensor).item()
- metrics.append(
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_stat_std",
- value=std_val,
- agg_type=AggregationType.MEAN,
- metadata={"sum": std_val, "count": 1},
- )
- )
- return metrics
-
- def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState:
- raise NotImplementedError(
- "Metrics with AggregationType.STATS were converted to other "
- "AggregationTypes for distributed reduction. finalize_dist_agg should not be called."
- )
-
- def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """Convert deque to list for serialization."""
- serialized = metadata.copy()
- if "values" in serialized:
- serialized["values"] = list(serialized["values"])
- return serialized
-
- def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """Convert list back to deque."""
- deserialized = metadata.copy()
- if "values" in deserialized:
- deserialized["values"] = deque(
- deserialized["values"], maxlen=self.window_size
- )
- return deserialized
-
-
-class CategoricalCountAggHandler(AggregationHandler):
- """AggHandler for CATEGORICAL_COUNT aggregation. Counts occurrences of categorical values
- and expands into individual count metrics for each category."""
-
- def initialize_metric_state(
- self, source: str, metric_name: str, agg_type: AggregationType
- ) -> MetricState:
- return MetricState(
- source=source,
- metric_name=metric_name,
- value=0.0,
- agg_type=agg_type,
- metadata={"counts": Counter()},
- )
-
- def update(self, local_agg_metric: MetricState, metric: Metric) -> None:
- local_agg_metric.metadata["counts"][metric.value] += 1
-
- def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]:
- # Expand categorical counts into individual metrics
- results = []
- for category, count in local_agg_metric.metadata["counts"].items():
- results.append(
- MetricState(
- source=local_agg_metric.source,
- metric_name=f"{local_agg_metric.metric_name}_count_{category}",
- value=count,
- agg_type=AggregationType.SUM,
- )
- )
- return results
-
- def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState:
- raise NotImplementedError(
- "Metrics with AggregationType.CATEGORICAL_COUNT were converted to other "
- "AggregationType.SUM for distributed reduction. finalize_dist_agg should not be called."
- )
-
- def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """Convert Counter to dict for serialization."""
- serialized = metadata.copy()
- if "counts" in serialized:
- serialized["counts"] = dict(serialized["counts"])
- return serialized
-
- def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """Convert dict back to Counter."""
- deserialized = metadata.copy()
- if "counts" in deserialized:
- deserialized["counts"] = Counter(deserialized["counts"])
- return deserialized
diff --git a/src/forge/data/dataset_metrics/metric_aggregator.py b/src/forge/data/dataset_metrics/metric_aggregator.py
deleted file mode 100644
index 40d8075ce..000000000
--- a/src/forge/data/dataset_metrics/metric_aggregator.py
+++ /dev/null
@@ -1,344 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import ast
-import logging
-from collections import defaultdict
-from typing import Any, Union
-
-import torch.distributed as dist
-
-from .metric_agg_handlers import (
- AggregationHandler,
- CategoricalCountAggHandler,
- MaxAggHandler,
- MeanAggHandler,
- MetricState,
- MinAggHandler,
- StatsAggHandler,
- SumAggHandler,
-)
-from .metric_transform import AggregationType, Metric
-
-logger = logging.getLogger(__name__)
-
-
-class MetricsAggregator:
- """Aggregates metrics across datasets and distributed ranks using pluggable handlers.
-
- This class uses a handler-based strategy, where each aggregation type (SUM, MEAN, etc.)
- has a corresponding AggregationHandler. It maintains a single state object for each
- (source, metric_name) pair.
-
- Internal State Visualization:
- {
- ("alpaca", "tokens_seen"): MetricState(value=200.0, agg_type=SUM, ...),
- ("alpaca", "avg_loss"): MetricState(value=0.01, agg_type=MEAN, metadata={'sum': ..., 'count': ...}),
- ("slim_orca", "seq_len"): MetricState(agg_type=STATS, metadata={'values': deque([...])}),
- }
-
- When preparing metrics for logging, the aggregator follows a two-phase process:
- 1. Local Aggregation: Each rank aggregates its metrics independently
- 2. Distributed Reduction: If in distributed mode, results are combined across ranks
-
- The aggregator's state is checkpointable, allowing training resumption.
-
- Args:
- dist_window_size (int): Window size for StatsAggHandler tracking.
-
- Example:
- >>> from forge.data.metrics import MetricsAggregator, Metric, AggregationType
- >>>
- >>> aggregator = MetricsAggregator()
- >>>
- >>> # Sample metrics from different batches
- >>> batch1_metrics = [
- ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM),
- ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN),
- ... ]
- >>>
- >>> batch2_metrics = [
- ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM),
- ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN),
- ... ]
- >>>
- >>> # Update with metrics
- >>> aggregator.update(batch1_metrics)
- >>> aggregator.update(batch2_metrics)
- >>>
- >>> # Get final results
- >>> results = aggregator.get_metrics_for_logging(prefix="train")
- >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0}
-
- Raises:
- ValueError: If dist_window_size is not positive.
- """
-
- def __init__(self, dist_window_size: int = 1000):
- if dist_window_size <= 0:
- raise ValueError(
- f"dist_window_size must be positive, got {dist_window_size}"
- )
-
- # Storage: {(source, metric_name): MetricState} - O(unique metrics) not O(samples)
- self._metric_states: dict[tuple[str, str], MetricState] = {}
- self._dist_window_size = dist_window_size
-
- # Track aggregation types for validation - prevents same metric name with different agg types
- self._metric_agg_types: dict[tuple[str, str], AggregationType] = {}
-
- # Create handler registry - all handlers initialized upfront
- self._handlers: dict[AggregationType, AggregationHandler] = {
- AggregationType.SUM: SumAggHandler(),
- AggregationType.MAX: MaxAggHandler(),
- AggregationType.MIN: MinAggHandler(),
- AggregationType.MEAN: MeanAggHandler(),
- AggregationType.STATS: StatsAggHandler(dist_window_size),
- AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(),
- }
-
- def _validate_metric_consistency(self, metric: Union[Metric, MetricState]) -> None:
- """Validate that metric name uses consistent aggregation type."""
- metric_key = (metric.source, metric.metric_name)
- metric_name = metric.metric_name
-
- if metric_key in self._metric_agg_types:
- existing_agg_type = self._metric_agg_types[metric_key]
- if existing_agg_type != metric.agg_type:
- raise ValueError(
- f"Metric '{metric_name}' in dataset '{metric.source}' "
- f"is already registered with aggregation type {existing_agg_type.value}, "
- f"but a handler or user code tried to use it with type {metric.agg_type.value}. "
- f"Use different metric names for different aggregation types."
- )
- else:
- # Track this metric's aggregation type
- self._metric_agg_types[metric_key] = metric.agg_type
-
- def register_handler(
- self, agg_type: AggregationType, handler: AggregationHandler
- ) -> None:
- """Register custom aggregation handler for specified type.
-
- Args:
- agg_type (AggregationType): The aggregation type to handle
- handler (AggregationHandler): Handler instance implementing the AggregationHandler interface
- """
- # Warn if replacing a handler that's already in use
- if agg_type in self._handlers and any(
- state.agg_type == agg_type for state in self._metric_states.values()
- ):
- logger.warning(
- f"Replacing handler for {agg_type} - aggregation type already in use by existing metrics. "
- f"This may affect existing metric behavior."
- )
-
- self._handlers[agg_type] = handler
-
- def update(self, metrics: list[Metric]) -> None:
- """Update (source, metric_name) metric state with new values.
-
- Args:
- metrics (list[Metric]): List of metrics to update the state with
-
- Raises:
- ValueError: If no handler is registered for a metric's aggregation type,
- or if metric name conflicts with existing aggregation type.
- """
- for metric in metrics:
- # Same metric name must use same aggregation type
- self._validate_metric_consistency(metric)
-
- metric_key = (metric.source, metric.metric_name)
- handler = self._handlers.get(metric.agg_type)
-
- if handler is None:
- raise ValueError(
- f"No handler registered for aggregation type: {metric.agg_type}"
- )
-
- if metric_key not in self._metric_states:
- self._metric_states[metric_key] = handler.initialize_metric_state(
- metric.source, metric.metric_name, metric.agg_type
- )
-
- local_agg_metric = self._metric_states[metric_key]
- handler.update(local_agg_metric, metric) # Mutates local_agg_metric
-
- def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]:
- """Get final metrics for logging in standard format.
-
- Args:
- prefix (str): Prefix for metric names in the returned dictionary
-
- Returns:
- dict[str, float]: Dictionary with keys like "{prefix}_{source}/{metric_name}"
- and float values. For example, with `prefix="train"`, `source="alpaca"`,
- `metric_name="loss"`, the key would be `train_alpaca/loss`.
- """
- final_results = self._compute_unified_metrics()
-
- return {
- f"{prefix}_{result.source}/{result.metric_name}": result.value
- for result in final_results
- }
-
- def _compute_unified_metrics(self) -> list[MetricState]:
- """
- Compute metrics handling both local and distributed cases uniformly.
-
- Returns:
- list[MetricState]: Final results ready for logging
- """
- # Step 1: Get local results from all handlers (may expand stats/categoricals)
- prepared_results = []
- for local_agg_metric in self._metric_states.values():
- handler = self._handlers[local_agg_metric.agg_type]
- generated_metrics = handler.finalize_local_agg(local_agg_metric)
-
- # Validate each newly generated metric state immediately
- for gen_metric in generated_metrics:
- self._validate_metric_consistency(gen_metric)
-
- prepared_results.extend(generated_metrics)
-
- # Step 2: Apply distributed reduction if needed
- if dist.is_initialized() and dist.get_world_size() > 1:
- prepared_results = self._finalize_dist_agg(prepared_results)
-
- return prepared_results
-
- def _finalize_dist_agg(
- self, local_agg_metrics: list[MetricState]
- ) -> list[MetricState]:
- """Apply distributed reduction to local results.
-
- Args:
- local_agg_metrics (list[MetricState]): (source, metric_name) metric pairs from this rank
-
- Returns:
- list[MetricState]: Reduced results combining all ranks
- """
- world_size = dist.get_world_size()
-
- # Gather all results from all ranks
- all_results = [None] * world_size
- dist.all_gather_object(all_results, local_agg_metrics)
-
- # Group by (source, metric_name) for reduction
- grouped = defaultdict(list)
- for rank_results in all_results:
- if rank_results: # Handle ranks with no metrics
- for result in rank_results:
- result_key = (result.source, result.metric_name)
- grouped[result_key].append(result)
-
- # Apply handler-specific distributed reduction
- reduced_results = []
- for result_key, results_list in grouped.items():
- if not results_list:
- continue # Skip empty groups
-
- # All results for a key should have same agg_type
- agg_type = results_list[0].agg_type
- handler = self._handlers[agg_type]
- reduced_result = handler.finalize_dist_agg(results_list)
- reduced_results.append(reduced_result)
-
- return reduced_results
-
- def state_dict(self) -> dict[str, Any]:
- """Serialize aggregator state for checkpointing.
-
- Returns:
- dict[str, Any]: Serializable dictionary containing all aggregator state
- """
- serializable_state = {}
- required_agg_types = set() # Track aggregation types used in saved states
-
- for metric_key, local_agg_metric in self._metric_states.items():
- # Get handler for this result's aggregation type
- handler = self._handlers[local_agg_metric.agg_type]
- required_agg_types.add(local_agg_metric.agg_type)
-
- # Convert MetricState to serializable dict
- result_dict = {
- "source": local_agg_metric.source,
- "metric_name": local_agg_metric.metric_name,
- "value": local_agg_metric.value,
- "agg_type": local_agg_metric.agg_type,
- "metadata": handler.serialize_metadata(local_agg_metric.metadata),
- }
-
- # Convert tuple key to string for JSON compatibility
- serializable_state[str(metric_key)] = result_dict
-
- return {
- "state": serializable_state,
- "dist_window_size": self._dist_window_size,
- "required_agg_types": list(
- required_agg_types
- ), # Save which handlers are needed
- # Save which aggregation types are used for each metric
- "metric_agg_types": {
- str(k): v.value for k, v in self._metric_agg_types.items()
- },
- }
-
- def load_state_dict(self, state_dict: dict[str, Any]) -> None:
- """Load aggregator state from checkpoint.
-
- Args:
- state_dict (dict[str, Any]): Dictionary containing serialized aggregator state
-
- Raises:
- ValueError: If required handlers are missing after checkpoint restore
- """
- self._dist_window_size = state_dict.get("dist_window_size", 1000)
-
- # Sanity check: Ensure all required handlers are available
- required_agg_types = state_dict.get("required_agg_types", [])
- missing_handlers = []
- for agg_type in required_agg_types:
- if agg_type not in self._handlers:
- missing_handlers.append(agg_type)
-
- if missing_handlers:
- raise ValueError(
- f"Missing handlers for aggregation types: {missing_handlers}. "
- f"Custom handlers must be re-registered before checkpoint restore."
- )
-
- deserialized_state = {}
- for key_str, result_dict in state_dict["state"].items():
- # Convert string keys back to tuples
- metric_key = ast.literal_eval(key_str)
-
- # Get handler for this aggregation type
- agg_type = result_dict["agg_type"]
- handler = self._handlers[agg_type]
-
- # Restore metadata using handler-specific deserialization
- metadata = handler.deserialize_metadata(result_dict["metadata"])
-
- # Create MetricState from dict
- local_agg_metric = MetricState(
- source=result_dict["source"],
- metric_name=result_dict["metric_name"],
- value=result_dict["value"],
- agg_type=result_dict["agg_type"],
- metadata=metadata,
- )
-
- deserialized_state[metric_key] = local_agg_metric
-
- self._metric_states = deserialized_state
-
- # Restore validation state
- self._metric_agg_types = {}
- for key_str, agg_type_str in state_dict.get("metric_agg_types", {}).items():
- key = ast.literal_eval(key_str)
- self._metric_agg_types[key] = AggregationType(agg_type_str)
diff --git a/src/forge/data/dataset_metrics/metric_transform.py b/src/forge/data/dataset_metrics/metric_transform.py
deleted file mode 100644
index 2898c8e43..000000000
--- a/src/forge/data/dataset_metrics/metric_transform.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Union
-
-from forge.interfaces import Transform
-
-
-@dataclass(frozen=True)
-class Metric:
- source: str
- metric_name: str
- value: Union[int, float, str]
- agg_type: "AggregationType"
-
-
-class AggregationType(Enum):
- """Defines how a metric's value should be aggregated by the MetricsAggregator.
-
- Each type corresponds to a specific AggregationHandler that implements the logic
- for initialization, updates, and distributed reduction.
- """
-
- SUM = "sum"
- MEAN = "mean"
- STATS = "distribution"
- CATEGORICAL_COUNT = "categorical_count"
- MAX = "max"
- MIN = "min"
-
-
-class MetricTransform(Transform, ABC):
- """Applied to each dataset sample to generate per-sample metrics for training tracking.
-
- Creates Metric objects that are later aggregated by MetricsAggregator. This separation
- of concerns ensures metrics are correctly aggregated even with multiple dataloader
- workers and in distributed settings.
-
- The transform must be configured with a source via set_source() before use.
- Each call to __call__ adds metrics to the sample's "metrics" key.
-
- Example:
- >>> transform = DefaultTrainingMetricTransform()
- >>> transform.set_source("alpaca")
- >>> sample = {"tokens": [1, 2, 3]}
- >>> result = transform(sample)
- >>> # result["metrics"] contains list of Metric objects
- """
-
- def set_source(self, source: str) -> None:
- """Called by the dataset to set the namespace for metrics.
-
- This is used to differentiate metrics from multiple datasets, for example,
- "alpaca/tokens_seen" vs. "slim_orca/tokens_seen".
-
- Args:
- source (str): Name of the dataset, used for logging and disambiguation.
- """
- self.source = source
-
- @abstractmethod
- def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]:
- """Generate metrics for a single sample.
-
- Args:
- sample (dict[str, Any]): The sample dictionary to generate metrics from
-
- Returns:
- list[Metric]: List of metrics generated for this sample
-
- Raises:
- NotImplementedError: If subclass does not implement this method.
- """
- raise NotImplementedError("Subclasses must implement _generate_metrics method")
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- if not hasattr(self, "source"):
- raise RuntimeError(
- "'transform.set_source' must be called before using the transform."
- )
-
- # Generate metrics for this sample
- metrics = self._generate_metrics(sample)
-
- # Add to existing metrics list or create new one
- if "metrics" not in sample:
- sample["metrics"] = []
- sample["metrics"].extend(metrics)
- return sample
-
-
-class DefaultTrainingMetricTransform(MetricTransform):
- """Generates common training metrics: samples seen, tokens seen, and sequence length.
-
- This transform detects the token key in a sample, checking for "tokens"
- first and then falling back to "input_ids".
-
- For details on the base class behavior, see MetricTransform.
-
- Tracked metrics:
- - samples_seen: Cumulative count of samples processed (SUM aggregation)
- - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation)
- - seq_len: Distribution stats of sequence lengths (STATS aggregation)
-
- Example:
- >>> transform = DefaultTrainingMetricTransform()
- >>> transform.set_source("alpaca")
- >>>
- >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens
- >>> metrics = transform._generate_metrics(sample)
- >>> # This generates the following Metric objects:
- >>> # [
- >>> # Metric(source="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM),
- >>> # Metric(source="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM),
- >>> # Metric(source="alpaca", metric_name="seq_len", value=5, agg_type=AggregationType.STATS)
- >>> # ]
- """
-
- def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]:
- # Determine token key
- token_key = "tokens" if "tokens" in sample else "input_ids"
- token_len = len(sample.get(token_key, []))
-
- # Create metrics for this sample
- return [
- Metric(
- source=self.source,
- metric_name="samples_seen",
- value=1,
- agg_type=AggregationType.SUM,
- ),
- Metric(
- source=self.source,
- metric_name="tokens_seen",
- value=token_len,
- agg_type=AggregationType.SUM,
- ),
- Metric(
- source=self.source,
- metric_name="seq_len",
- value=token_len,
- agg_type=AggregationType.STATS,
- ),
- ]
diff --git a/src/forge/data/dataset_metrics/readme.md b/src/forge/data/dataset_metrics/readme.md
deleted file mode 100644
index 76ec424e3..000000000
--- a/src/forge/data/dataset_metrics/readme.md
+++ /dev/null
@@ -1,176 +0,0 @@
-# forge Metrics Module
-
-## Overview
-
-The metrics module provides a robust system for tracking and aggregating training metrics across multiple datasets and distributed environments. It follows a **strategy pattern** design with pluggable aggregation handlers to efficiently handle different types of metrics.
-
-## Architecture Overview
-
-```
-┌────────────────────────────────────────────────────┐
-│ Training Loop │
-└─────────────────────┬──────────────────────────────┘
- │
-┌─────────────────────▼──────────────────────────────┐
-│ MetricTransform │
-│ • Applied to each sample │
-│ • Generates per-sample metrics │
-│ • Examples: tokens_seen, seq_len, samples_seen │
-└─────────────────────┬──────────────────────────────┘
- │ list[Metric]
-┌─────────────────────▼──────────────────────────────┐
-│ MetricsAggregator │
-│ • Aggregates metrics across samples and ranks │
-│ • Uses pluggable AggregationHandlers │
-│ • Handles distributed reduction │
-└─────────────────────┬──────────────────────────────┘
- │ {prefix}_{source}/{metric_name} # prefix is "train", "val", etc.
-┌─────────────────────▼──────────────────────────────┐
-│ Logging System │
-│ • W&B, TensorBoard, etc. │
-│ • Gets formatted metrics ready for logging │
-└────────────────────────────────────────────────────┘
-```
-
-## File Structure
-
-- **`metric_transform.py`**: Defines `Metric`, `AggregationType`, and transform classes
-- **`metric_agg_handlers.py`**: Aggregation strategy implementations
-- **`metric_aggregator.py`**: Main aggregator orchestrating the handlers
-
-## Customizing metrics
-
-- **Custom transforms**: Extend `MetricTransform` for domain-specific metrics
-- **Handler registration**: Register custom handlers for specialized aggregation needs
-
-#######
-## TODO
-## Move this from here to website docs
-#######
-
-## Core Components
-
-### 1. MetricTransform
-Generates per-sample metrics during data processing.
-
-**Key Features:**
-- Applied to each sample in the dataset
-- Creates `Metric` objects with dataset name, metric name, value, and aggregation type
-- Handles dataset namespacing for multi-dataset scenarios
-
-**Example Usage:**
-```python
-from forge.data.metrics import DefaultTrainingMetricTransform, AggregationType
-
-transform = DefaultTrainingMetricTransform()
-transform.set_source("alpaca")
-
-# Applied to each sample
-sample = {"tokens": [1, 2, 3, 4, 5]}
-sample = transform(sample)
-# sample["metrics"] now contains:
-# [
-# Metric(source="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM),
-# Metric(source="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM),
-# Metric(source="alpaca", name="seq_len", value=5, agg_type=AggregationType.STATS)
-# ]
-```
-
-### 2. MetricsAggregator
-Efficiently aggregates metrics across samples and distributed ranks.
-
-**Key Features:**
-- Handler-based strategy pattern for different aggregation types
-- Distributed-aware with automatic rank reduction
-- Checkpointable state for training resumption
-- Keep track of (metric, dataset) pairs
-
-**Aggregation Types (at the time of writing):**
-- `SUM`: Cumulative totals (e.g., total tokens processed)
-- `MEAN`: Running averages (e.g., average loss)
-- `MAX/MIN`: Extrema tracking (e.g., max sequence length seen)
-- `STATS`: Statistical summaries (mean, min, max, percentiles)
-- `CATEGORICAL_COUNT`: Category cumulative counts (e.g. num of samples from a given category)
-
-**Example Usage:**
-```python
-from forge.data.metrics import MetricsAggregator, Metric, AggregationType
-
-# Create aggregator
-aggregator = MetricsAggregator()
-
-# Sample metrics from different batches
-batch1_metrics = [
- Metric("alpaca", "tokens_seen", 100, AggregationType.SUM),
- Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN),
-]
-
-batch2_metrics = [
- Metric("alpaca", "tokens_seen", 100, AggregationType.SUM),
- Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN),
-]
-
-# Update with metrics
-aggregator.update(batch1_metrics)
-aggregator.update(batch2_metrics)
-
-# Get final results
-results = aggregator.get_metrics_for_logging(prefix="train")
-# {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0}
-```
-
-### 3. AggregationHandlers
-Pluggable strategies for different aggregation patterns.
-
-```
-AggregationHandler (ABC)
-├── SumAggHandler # value += metric.value
-├── MeanAggHandler # tracks sum and count
-├── MaxAggHandler # value = max(value, metric.value)
-├── MinAggHandler # value = min(value, metric.value)
-├── StatsAggHandler # maintains value window + stats
-└── CategoricalCountAggHandler # Counter for categories
-```
-
-**Custom Handler Example:**
-```python
-class CustomAggHandler(AggregationHandler):
- def initialize_metric_state(self, source, metric_name, agg_type):
- return MetricState(
- source=source,
- metric_name=metric_name,
- value=, # should change
- agg_type=agg_type,
- metadata={} # may need to change
- )
-
- def update(self, local_agg_metric, metric):
- ...
-
- def finalize_local_agg(self, local_agg_metric):
- ...
-
- def finalize_dist_agg(self, local_agg_metrics):
- ...
-
-# Register with aggregator
-aggregator.register_handler(AggregationType.CUSTOM, CustomAggHandler())
-```
-
-## Distributed Training Support
-
-The metrics system automatically handles distributed environments:
-
-1. **Local Aggregation**: Each rank aggregates its own metrics
-2. **Distributed Reduction**: Results are combined across ranks using `all_gather_object`
-3. **Type-Aware Reduction**: Each aggregation type uses appropriate reduction (sum, mean, max, etc.)
-
-**Distributed Flow:**
-```
-Rank 0: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)]
-Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)]
- ↓
- AllGather + Reduce
- ↓
- Final Results [(ds1, metric1), (ds1, metric2)]
-```
diff --git a/src/forge/data/datasets/dataset.py b/src/forge/data/datasets/dataset.py
index 57a624c67..f18d9e07e 100644
--- a/src/forge/data/datasets/dataset.py
+++ b/src/forge/data/datasets/dataset.py
@@ -61,7 +61,7 @@ class DatasetInfo:
class TuneIterableDataset(IterableDataset, ABC):
- """Base class for all torchtune iterable datasets.
+ """Base class for all forge iterable datasets.
Datasets are composable, enabling complex structures such as:
``PackedDataset(InterleavedDataset([InterleavedDataset([ds1, ds2]), ds3]))``
diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py
index 56b17712b..75b484607 100644
--- a/src/forge/data/datasets/hf_dataset.py
+++ b/src/forge/data/datasets/hf_dataset.py
@@ -5,20 +5,15 @@
# LICENSE file in the root directory of this source tree.
import logging
-from typing import Any, Callable, Iterator, Optional
+from typing import Any, Callable, Iterator
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
-from forge.data.dataset_metrics import (
- AggregationType,
- DefaultTrainingMetricTransform,
- Metric,
- MetricTransform,
-)
-from forge.interfaces import Transform
+from forge.data.metric_transform import DefaultDatasetMetricTransform, MetricTransform
+from forge.observability.metrics import Metric, Reduce
from .dataset import DatasetInfo, InfiniteTuneIterableDataset
@@ -37,26 +32,26 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
- Returning an infinite iterator over the dataset
Args:
- message_transform (Optional[Transform]): Transforms raw data into a `Message`.
- model_transform (Optional[Transform]): Prepares messages for the model,
+ message_transform (Callable | None): Transforms raw data into a `Message`.
+ model_transform (Callable | None): Prepares messages for the model,
usually by tokenizing them.
- output_transform (Optional[Transform]): Prepares tokenized inputs for the
+ output_transform (Callable | None): Prepares tokenized inputs for the
recipe, often by manipulating labels (e.g., setting an ignore index).
This transform is recipe-dependent (e.g., SFT, DPO, etc.).
- metric_transform (Optional[MetricTransform]): Computes metrics from a
+ metric_transform (MetricTransform | None): Computes metrics from a
sample (e.g., token count). If ``None``, a default transform is used.
To disable standard metric tracking, set this to ``lambda x: x``.
- shuffle_buffer_size (Optional[int]): Size of the shuffle buffer.
+ shuffle_buffer_size (int | None): Size of the shuffle buffer.
If ``None`` or 0, no shuffling is performed.
- weight (Optional[float]): Weight for this dataset. Defaults to 1.0.
+ weight (float | None): Weight for this dataset. Defaults to 1.0.
seed (int): Seed for shuffling.
num_shards_per_rank (int): The target number of shards per worker (GPU).
The actual number of shards will be a multiple of
``world_size * dataloader_workers``.
- dataset_name (Optional[str]): Name of the dataset. If ``None``, a name is
+ dataset_name (str | None): Name of the dataset. If ``None``, a name is
generated from the ``path``, ``source``, and ``split``.
- filter_fn (Optional[Callable]): A function to filter the dataset.
- filter_kwargs (Optional[dict[str, Any]]): Keyword arguments for ``filter_fn``.
+ filter_fn (Callable | None): A function to filter the dataset.
+ filter_kwargs (dict[str, Any] | None): Keyword arguments for ``filter_fn``.
**load_dataset_kwargs: Keyword arguments for the
:func:`~datasets.load_dataset` function.
"""
@@ -64,17 +59,18 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
def __init__(
self,
*,
- message_transform: Optional[Transform] = None,
- model_transform: Optional[Transform] = None,
- output_transform: Optional[Transform] = None,
- metric_transform: Optional[MetricTransform] = None,
- shuffle_buffer_size: Optional[int] = 1000,
- weight: Optional[float] = 1.0,
+ message_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+ model_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+ output_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+ metric_transform: MetricTransform | None = None,
+ shuffle_buffer_size: int | None = 1000,
+ weight: float | None = 1.0,
seed: int = 42,
num_shards_per_rank: int = 64,
- dataset_name: Optional[str] = None,
- filter_fn: Optional[Callable] = None,
- filter_kwargs: Optional[dict[str, Any]] = None,
+ dataset_name: str | None = None,
+ filter_fn: Callable | None = None,
+ filter_kwargs: dict[str, Any] | None = None,
+ dp_mesh: dist.ProcessGroup | None = None,
**load_dataset_kwargs,
):
# Store configuration
@@ -84,9 +80,10 @@ def __init__(
self._model_transform = model_transform
self._output_transform = output_transform
self._weight = weight if weight is not None else 1.0
+ self._dp_mesh = dp_mesh
# Create default transform if not provided
- self._metric_transform = metric_transform or DefaultTrainingMetricTransform()
+ self._metric_transform = metric_transform or DefaultDatasetMetricTransform()
# Auto-generate dataset name if not provided
if dataset_name is None:
@@ -107,6 +104,10 @@ def __init__(
self._metric_transform.set_source(dataset_name)
# Internal state for resumption
+ # _start_epoch: The epoch to start from. Updated on resume from ckpt.
+ # useful when doing iter(ds), which restarts dataset from original state.
+ self._start_epoch = 0
+ # _num_epochs: updated on every dataset exhaustion
self._num_epochs = 0
# Load and setup HF dataset
@@ -135,20 +136,33 @@ def _setup_hf_dataset(
self,
load_dataset_kwargs: dict[str, Any],
num_shards_per_rank: int,
- filter_fn: Optional[Callable] = None,
- filter_kwargs: Optional[dict[str, Any]] = None,
+ filter_fn: Callable | None = None,
+ filter_kwargs: dict[str, Any] | None = None,
):
"""
One-time setup of HuggingFace dataset that handles Handles distributed sharding,
shuffle configuration, and filtering. Called once during __init__.
"""
- # Distributed setup
+ # Extract rank/world_size from DP mesh
world_size, rank = 1, 0
- if dist.is_initialized():
+ if self._dp_mesh is not None:
+ world_size = dist.get_world_size(group=self._dp_mesh)
+ rank = dist.get_rank(group=self._dp_mesh)
+ logger.debug(
+ f"Using DP mesh for sharding: rank={rank}, world_size={world_size}"
+ )
+ elif dist.is_initialized():
+ # Fallback to global rank (may not respect TP/PP)
world_size = dist.get_world_size()
rank = dist.get_rank()
+ # TODO: is there a way to detect this and raise error instead?
+ logger.warning(
+ f"Using global rank for sharding: rank={rank}, world_size={world_size}. "
+ f"If using other types of parallelsim (CP/TP/PP), pass dp_mesh for correct sharding."
+ )
+
# Load and shard dataset
ds = load_dataset(**load_dataset_kwargs)
@@ -157,7 +171,6 @@ def _setup_hf_dataset(
if is_streaming:
logger.warning(
f"Streaming datasets were not yet tested for distributed training. "
- f"split_dataset_by_node is applied, but no resharding was done manually. "
f"Dataset '{self.info.name}' has "
f"{getattr(ds, 'num_shards', 'unknown')} shards, and your training has {world_size} ranks."
f"See: https://huggingface.co/docs/datasets/en/package_reference/main_classes?#datasets.IterableDataset.shard"
@@ -192,7 +205,7 @@ def _setup_hf_dataset(
if num_shards > dataset_size:
raise ValueError(
f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})."
- f"Please decrease one of {num_shards_per_rank=} or {num_dataloader_workers=} or {world_size=}."
+ f"Please decrease one of {num_shards_per_rank=} or dataloader.num_workers={num_dataloader_workers}"
)
ds = ds.to_iterable_dataset(num_shards=num_shards)
@@ -223,6 +236,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
- Adds 'num_epochs' metric to track dataset progress
- Yields samples indefinitely for continuous training
"""
+ self._num_epochs = self._start_epoch
while True: # Infinite iteration
self._ds.set_epoch(self._num_epochs)
@@ -240,15 +254,16 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
# Track the number of epochs completed for each dataset. This is
# especially useful when interleaving multiple datasets, but
# also necessary to track dataset-level metrics.
- metric_num_epochs = Metric(
- source=self.info.name,
- metric_name="num_epochs",
- value=self._num_epochs,
- agg_type=AggregationType.MAX,
- )
if "metrics" not in sample:
sample["metrics"] = []
- sample["metrics"].append(metric_num_epochs)
+
+ sample["metrics"].append(
+ Metric(
+ key=f"dataset/{self.info.name}/num_epochs",
+ value=self._num_epochs,
+ reduction=Reduce.MAX,
+ )
+ )
samples_yielded += 1
yield sample
@@ -280,7 +295,7 @@ def state_dict(self) -> dict[str, Any]:
return state
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
- self._num_epochs = state_dict["num_epochs"]
+ self._start_epoch = state_dict["num_epochs"]
hf_state = state_dict["hf_dataset_state"]
# HF is responsible for resuming the dataset state
diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py
index 105921acb..e7e8f03c1 100644
--- a/src/forge/data/datasets/packed.py
+++ b/src/forge/data/datasets/packed.py
@@ -7,17 +7,17 @@
import logging
from abc import ABC, abstractmethod
from collections import deque
-from typing import Any, Generic, Iterable, Iterator, Optional, TypeVar
+from typing import Any, Generic, Iterable, Iterator, TypeVar
import torch
+
+from forge.data import CROSS_ENTROPY_IGNORE_IDX
+from forge.observability.metrics import Metric, Reduce
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_mask_flex,
)
from torchdata.stateful_dataloader import Stateful
-from forge.data import CROSS_ENTROPY_IGNORE_IDX
-from forge.data.dataset_metrics import AggregationType, Metric
-
from .dataset import DatasetInfo, InfiniteTuneIterableDataset
logger = logging.getLogger(__name__)
@@ -329,13 +329,13 @@ def _reset_packer_state(self) -> None:
self._buffer.clear()
# current_pack: the current pack being built
- self._current_pack: Optional[dict[str, list]] = None
+ self._current_pack: dict[str, list] | None = None
# current_pack_size: the number of tokens in the current pack
self._current_pack_size: int = 0
# iterator: the iterator over the dataset
- self._iterator: Optional[Iterator[SampleType]] = None
+ self._iterator: Iterator[SampleType] | None = None
# current_doc_id_in_pack: the document ID to use for the next sample
self._current_doc_id_in_pack: int = 0
@@ -343,9 +343,6 @@ def _reset_packer_state(self) -> None:
# exhausted: whether the dataset is exhausted
self._exhausted: bool = False
- # resuming: whether the packer is resuming from a checkpoint
- self._resuming: bool = False
-
def _fill_buffer(self, iterator: Iterator[SampleType]) -> None:
"""
Fills the buffer with samples from the dataset.
@@ -367,7 +364,7 @@ def _fill_buffer(self, iterator: Iterator[SampleType]) -> None:
except StopIteration:
self._exhausted = True
- def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]:
+ def _find_next_fitting_sample(self, remaining_size: int) -> int | None:
"""
Find the first sample in the buffer that fits in the remaining space.
@@ -375,7 +372,7 @@ def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]:
remaining_size (int): The remaining space in the current pack.
Returns:
- Optional[int]: The index of the sample in the buffer, or None if no sample fits.
+ int | None: The index of the sample in the buffer, or None if no sample fits.
Example:
self._buffer = deque([(sample1, 200), (sample2, 100), (sample3, 48), (sample4, 200)])
@@ -397,7 +394,7 @@ def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]:
return i
return None
- def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[SampleDict]:
+ def _build_one_pack(self, iterator: Iterator[SampleType]) -> SampleDict | None:
"""
Builds a pack of samples from the buffer.
@@ -405,7 +402,7 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[SampleDict
iterator (Iterator[SampleType]): The iterator over the dataset.
Returns:
- Optional[SampleDict]: The pack of samples, or None if the dataset is exhausted.
+ SampleDict | None: The pack of samples, or None if the dataset is exhausted.
"""
# Start a new pack if necessary
if self._current_pack is None:
@@ -452,19 +449,11 @@ def __iter__(self) -> Iterator[SampleDict]:
if not isinstance(self.dataset, Iterable):
raise TypeError("Dataset is not an iterable")
- if not self._resuming:
- self._reset_packer_state()
- self._iterator = iter(self.dataset)
-
- # If resuming, the iterator must be recreated from the loaded state
- if self._iterator is None:
- self._iterator = iter(self.dataset)
-
- self._resuming = False # Consume the resume flag
+ self._reset_packer_state()
+ self._iterator = iter(self.dataset)
# Main packing loop
while True:
-
# Stop if the source is exhausted and there's no data left to pack
if self._exhausted and not self._buffer and self._current_pack_size == 0:
break
@@ -502,7 +491,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
raise ValueError("Dataset is not stateful.")
self._reset_packer_state()
- self._resuming = True
class TextPacker(Packer[SampleDict]):
@@ -605,13 +593,13 @@ def finalize_pack(
# Add padding percentage metric
if target_tokens_per_pack > 0:
padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2)
- padding_metric = Metric(
- source=self.dataset_name,
- metric_name="pct_of_tokens_padded",
- value=padding_pct,
- agg_type=AggregationType.MEAN,
+ pack["metrics"].append(
+ Metric(
+ key=f"dataset/{self.dataset_name}/pct_of_tokens_padded",
+ value=padding_pct,
+ reduction=Reduce.MEAN,
+ )
)
- pack["metrics"].append(padding_metric)
# Concatenate tensor lists and handle other keys
result = {
@@ -635,7 +623,7 @@ def finalize_pack(
if pack["input_pos"]
else torch.empty(0, dtype=torch.long)
),
- # "metrics": pack["metrics"],
+ "metrics": pack["metrics"],
}
# Handle arbitrary keys that aren't tensors - keep as lists
@@ -853,13 +841,13 @@ def finalize_pack(
# Add padding percentage metric
if target_tokens_per_pack > 0:
padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2)
- padding_metric = Metric(
- source=self.dataset_name,
- metric_name="pct_of_tokens_padded",
- value=padding_pct,
- agg_type=AggregationType.MEAN,
+ pack["metrics"].append(
+ Metric(
+ key=f"dataset/{self.dataset_name}/pct_of_tokens_padded",
+ value=padding_pct,
+ reduction=Reduce.MEAN,
+ )
)
- pack["metrics"].append(padding_metric)
# Concatenate tensor lists and handle other keys
result = {
diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py
index e6f6edcfb..db16ca551 100644
--- a/src/forge/data/datasets/sft_dataset.py
+++ b/src/forge/data/datasets/sft_dataset.py
@@ -4,29 +4,28 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Any, Callable, Optional
+from typing import Any, Callable
import torch
+import torch.distributed as dist
from forge.data import CROSS_ENTROPY_IGNORE_IDX
-from forge.data.dataset_metrics import DefaultTrainingMetricTransform
+from forge.data.metric_transform import DefaultDatasetMetricTransform
from forge.data.utils import mask_messages, TuneMessage
-from forge.interfaces import Transform
from .hf_dataset import HfIterableDataset
-class AlpacaToMessages(Transform):
+class AlpacaToMessages:
"""
Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
(or equivalent fields specified in column_map) columns. User messages are formed from the
instruction + input columns and assistant messages are formed from the output column. Prompt
templating is conditional on the presence of the "input" column, and thus is handled directly
- in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class
- due to this custom logic.
+ in this transform class.
Args:
- column_map (Optional[dict[str, str]]): a mapping to change the expected "instruction", "input",
+ column_map (dict[str, str] | None): a mapping to change the expected "instruction", "input",
and "output" column names to the actual column names in the dataset. Default is None,
keeping the default column names.
masking_strategy (str): masking strategy to use for model training.
@@ -45,7 +44,7 @@ class AlpacaToMessages(Transform):
def __init__(
self,
- column_map: Optional[dict[str, str]] = None,
+ column_map: dict[str, str] | None = None,
masking_strategy: str = "train_on_all",
):
self.masking_strategy = masking_strategy
@@ -125,7 +124,6 @@ class SFTOutputTransform:
"""
def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
-
# Sanity checks
if not isinstance(sample["tokens"], torch.Tensor):
sample["tokens"] = torch.tensor(sample["tokens"])
@@ -154,31 +152,33 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
def sft_iterable_dataset(
- model_transform: Transform,
+ model_transform: Callable[[dict[str, Any]], dict[str, Any]],
*,
weight: int = 1,
- message_transform: Transform,
- shuffle_buffer_size: Optional[int] = 1000,
+ message_transform: Callable[[dict[str, Any]], dict[str, Any]],
+ shuffle_buffer_size: int | None = 1000,
seed: int = 42,
num_shards_per_rank: int = 64,
- dataset_name: Optional[str] = None,
- filter_fn: Optional[Callable] = None,
- filter_kwargs: Optional[dict[str, Any]] = None,
+ dataset_name: str | None = None,
+ filter_fn: Callable | None = None,
+ filter_kwargs: dict[str, Any] | None = None,
+ dp_mesh: dist.ProcessGroup | None = None,
**load_dataset_kwargs: dict[str, Any],
) -> HfIterableDataset:
"""
Creates an SFT-ready iterable dataset with appropriate output transform.
Args:
- model_transform (Transform): Usually the tokenizer
+ model_transform (Callable): Usually the tokenizer
weight (int): Weight of the dataset. Used for sampling when interleaving datasets.
- message_transform (Transform): Transform to convert raw data to messages
- shuffle_buffer_size (Optional[int]): Buffer size for shuffling
+ message_transform (Callable): Transform to convert raw data to messages
+ shuffle_buffer_size (int | None): Buffer size for shuffling
seed (int): Random seed for shuffling
num_shards_per_rank (int): Target shards per worker
- dataset_name (Optional[str]): Name for metrics namespacing
- filter_fn (Optional[Callable]): Filter function
- filter_kwargs (Optional[dict[str, Any]]): Filter function kwargs
+ dataset_name (str | None): Name for metrics namespacing
+ filter_fn (Callable | None): Filter function
+ filter_kwargs (dict[str, Any] | None): Filter function kwargs
+ dp_mesh (dist.ProcessGroup | None): Data parallel process group for sharding (None for single process)
**load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset
Returns:
@@ -200,7 +200,7 @@ def sft_iterable_dataset(
message_transform=message_transform,
model_transform=model_transform,
output_transform=output_transform,
- metric_transform=DefaultTrainingMetricTransform(),
+ metric_transform=DefaultDatasetMetricTransform(),
shuffle_buffer_size=shuffle_buffer_size,
weight=weight,
seed=seed,
@@ -208,5 +208,6 @@ def sft_iterable_dataset(
dataset_name=dataset_name,
filter_fn=filter_fn,
filter_kwargs=filter_kwargs,
+ dp_mesh=dp_mesh,
**load_dataset_kwargs,
)
diff --git a/src/forge/data/metric_transform.py b/src/forge/data/metric_transform.py
new file mode 100644
index 000000000..cbbc04020
--- /dev/null
+++ b/src/forge/data/metric_transform.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any
+
+from forge.observability.metrics import Metric, Reduce
+
+
+class MetricTransform:
+ """
+ Base class for transforms that collect observability metrics from dataset samples.
+
+ This class provides a foundation for implementing dataset-level metric collection
+ during data processing pipelines. Subclasses should override the __call__ method
+ to add specific metrics to each sample that passes through the transform.
+
+ Metrics are collected as `forge.observability.metrics.Metric` objects and made available
+ in batch["metrics"].
+
+ Attributes:
+ source (str, optional): The source name for metrics, typically the dataset name.
+ This is used as a prefix in metric keys to distinguish metrics from different
+ data sources.
+
+ Example:
+ >>> transform = SomeMetricTransform()
+ >>> transform.set_source("training_data")
+ >>> processed_sample = transform(sample)
+ >>> # Metrics are automatically added to sample["metrics"]
+ """
+
+ def __init__(self):
+ self.source = None
+
+ def set_source(self, source: str):
+ """Set the source name for metrics (typically the dataset name)."""
+ self.source = source
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Transform a sample by adding metrics to it."""
+ return sample
+
+
+class DefaultDatasetMetricTransform(MetricTransform):
+ """
+ Collects basic dataset processing metrics during data pipeline execution.
+
+ Metrics collected:
+ - samples_processed: Total number of samples that have passed through this transform (SUM)
+ - tokens_processed: Total number of tokens processed across all samples (SUM)
+ - mean_seq_len: Average sequence length across samples (MEAN)
+ - max_seq_len: Maximum sequence length observed (MAX)
+ - min_seq_len: Minimum sequence length observed (MIN)
+
+ Note: Token-related metrics are only collected if the sample contains a 'tokens' field.
+ Sequence length is measured as the number of tokens in each sample.
+
+ Example:
+ >>> collector = DefaultDatasetMetricTransform()
+ >>> collector.set_source("training_data")
+ >>> sample = {"tokens": ["hello", "world"]}
+ >>> processed_sample = collector(sample)
+ >>> # Metrics are automatically added to processed_sample["metrics"]
+ """
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ if "metrics" not in sample:
+ sample["metrics"] = []
+
+ source_name = self.source or "unnamed_ds"
+
+ # Add samples_processed metric
+ sample["metrics"].append(
+ Metric(
+ key=f"dataset/{source_name}/samples_processed",
+ value=1,
+ reduction=Reduce.SUM,
+ )
+ )
+
+ # Add token-based metrics if tokens are present
+ if "tokens" in sample:
+ token_count = len(sample.get("tokens", []))
+
+ sample["metrics"].extend(
+ [
+ Metric(
+ key=f"dataset/{source_name}/tokens_processed",
+ value=token_count,
+ reduction=Reduce.SUM,
+ ),
+ Metric(
+ key=f"dataset/{source_name}/mean_seq_len",
+ value=token_count,
+ reduction=Reduce.MEAN,
+ ),
+ Metric(
+ key=f"dataset/{source_name}/max_seq_len",
+ value=token_count,
+ reduction=Reduce.MAX,
+ ),
+ Metric(
+ key=f"dataset/{source_name}/min_seq_len",
+ value=token_count,
+ reduction=Reduce.MIN,
+ ),
+ ]
+ )
+
+ return sample
diff --git a/src/forge/data/sharding.py b/src/forge/data/sharding.py
deleted file mode 100644
index 2027f8a43..000000000
--- a/src/forge/data/sharding.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-
-
-class VLLMSharding:
- """
- vLLM specific tensor parallel sharding strategy.
- """
-
- def __init__(self, tensor_parallel_size: int, rank: int):
- self.tensor_parallel_size = tensor_parallel_size
- self.rank = rank
-
- def load_from_source_to_target(
- self,
- param_name: str,
- source_tensor: torch.Tensor,
- target_tensor: torch.Tensor,
- ) -> None:
- """
- Copy a source tensor to a target tensor, handling sharding and replication.
- """
- # Determine sharding strategy for this parameter
- shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(param_name)
-
- if not is_sharded:
- # Parameter is replicated - shapes should match exactly
- if source_tensor.shape != target_tensor.shape:
- raise ValueError(
- f"Replicated parameter {param_name} has mismatched shapes: "
- f"{source_tensor.shape} vs {target_tensor.shape}, skipping"
- )
-
- # Direct copy for replicated parameters
- target_tensor.copy_(source_tensor)
- else:
- # Need to shard the full tensor
- sharded_tensor = self._calculate_tensor_shard(
- source_tensor, shard_dim, self.tensor_parallel_size, self.rank
- )
-
- if sharded_tensor.shape != target_tensor.shape:
- raise ValueError(
- f"Calculated shard for {param_name} has wrong shape: "
- f"{sharded_tensor.shape} vs expected {target_tensor.shape}, skipping"
- )
-
- target_tensor.copy_(sharded_tensor)
-
- def _get_tensor_parallel_sharding_strategy(
- self, param_name: str
- ) -> tuple[int, bool]:
- """
- Determine the sharding strategy for a parameter in tensor parallel setup.
-
- Returns:
- tuple[int, bool]: (shard_dimension, is_sharded)
- - shard_dimension: Which dimension to shard (0 or 1)
- - is_sharded: Whether this parameter should be sharded at all
-
- Based on vLLM's tensor parallel implementation for LLaMA models:
- - Embedding layers: shard along vocab dimension (dim 0)
- - Attention projections: qkv_proj shard along hidden dimension (dim 0), o_proj along input dimension (dim 1)
- - MLP projections: gate/up_proj shard along hidden dimension (dim 0), down_proj along input dimension (dim 1)
- - Layer norms: not sharded (replicated)
- - Output layer: shard along vocab dimension (dim 0)
- """
- # Parameters that are not sharded (replicated across all tensor parallel ranks)
- if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
- return 0, False
-
- # Embedding layers - shard along vocab dimension (dim 0)
- if "embed_tokens" in param_name or "lm_head" in param_name:
- return 0, True
-
- # Attention projections
- if "qkv_proj" in param_name:
- # Input projections: shard output dimension (dim 0)
- return 0, True
- elif "o_proj" in param_name:
- # Output projection: shard input dimension (dim 1)
- return 1, True
-
- # MLP projections
- elif any(
- proj in param_name for proj in ["gate_proj", "up_proj", "gate_up_proj"]
- ):
- # Input projections: shard output dimension (dim 0)
- return 0, True
- elif "down_proj" in param_name:
- # Output projection: shard input dimension (dim 1)
- return 1, True
-
- # Default: try to infer from tensor shape patterns
- return 0, True
-
- def _calculate_tensor_shard(
- self,
- full_tensor: torch.Tensor,
- shard_dim: int,
- tensor_parallel_size: int,
- rank: int,
- ) -> torch.Tensor:
- """
- Calculate the shard of a full tensor for the current tensor parallel rank.
-
- Args:
- full_tensor: The full tensor to shard
- shard_dim: Which dimension to shard along (0 or 1)
- tensor_parallel_size: Number of tensor parallel ranks
- rank: Current rank (will be modulo by tensor_parallel_size)
-
- Returns:
- torch.Tensor: The sharded tensor for this rank
- """
- tp_rank = rank % tensor_parallel_size
- tensor_size = full_tensor.shape[shard_dim]
-
- if tensor_size % tensor_parallel_size != 0:
- raise ValueError(
- f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
- f"across {tensor_parallel_size} ranks: not evenly divisible"
- )
-
- shard_size = tensor_size // tensor_parallel_size
- start_idx = tp_rank * shard_size
- end_idx = start_idx + shard_size
-
- # Create index tensor for the shard range
- indices = torch.arange(start_idx, end_idx, device=full_tensor.device)
-
- if shard_dim == 0:
- return torch.index_select(full_tensor, 0, indices)
- elif shard_dim == 1:
- return torch.index_select(full_tensor, 1, indices)
- else:
- raise ValueError(f"Unsupported shard dimension: {shard_dim}")
diff --git a/src/forge/data/tokenizer.py b/src/forge/data/tokenizer.py
index 3cb90f79c..060a26e9f 100644
--- a/src/forge/data/tokenizer.py
+++ b/src/forge/data/tokenizer.py
@@ -8,19 +8,19 @@
from typing import Any, Optional
import jinja2
-from jinja2 import StrictUndefined
-
-from tokenizers import Tokenizer
from forge.data.utils import truncate
from forge.interfaces import BaseTokenizer, ModelTokenizer
from forge.types import Message
+from jinja2 import StrictUndefined
+
+from tokenizers import Tokenizer
class HuggingFaceBaseTokenizer(BaseTokenizer):
"""
A wrapper around Hugging Face tokenizers. See https://github.com/huggingface/tokenizers
- This can be used to load from a Hugging Face tokenizer.json file into a torchtune BaseTokenizer.
+ This can be used to load from a Hugging Face tokenizer.json file into a forge BaseTokenizer.
This class will load the tokenizer.json file from tokenizer_json_path. It will
attempt to infer BOS and EOS token IDs from config.json if possible, and if not
@@ -28,8 +28,8 @@ class HuggingFaceBaseTokenizer(BaseTokenizer):
Args:
tokenizer_json_path (str): Path to tokenizer.json file
- tokenizer_config_json_path (Optional[str]): Path to tokenizer_config.json file. Default: None
- generation_config_path (Optional[str]): Path to generation_config.json file.
+ tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
+ generation_config_path (str | None): Path to generation_config.json file.
Default: None
Raises:
@@ -40,8 +40,8 @@ def __init__(
self,
tokenizer_json_path: str,
*,
- tokenizer_config_json_path: Optional[str] = None,
- generation_config_path: Optional[str] = None,
+ tokenizer_config_json_path: str | None = None,
+ generation_config_path: str | None = None,
):
self.tokenizer = Tokenizer.from_file(tokenizer_json_path)
if not (tokenizer_config_json_path or generation_config_path):
@@ -61,7 +61,7 @@ def __init__(
self._infer_bos_eos_tokens()
self._infer_should_add_bos_eos()
- def _get_token_from_config(self, config: dict[str, Any], key: str) -> str:
+ def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[str]:
"""
HF BOS/EOS tokens are either stored as e.g. {'bos_token': 5}
or {'bos_token': {'content': 5, ...}}. This utility handles both.
@@ -72,7 +72,7 @@ def _get_token_from_config(self, config: dict[str, Any], key: str) -> str:
raise ValueError(f"Could not parse {key} from config")
token = token["content"]
else:
- if not isinstance(token, str):
+ if token is not None and not isinstance(token, str):
raise ValueError(f"Could not parse {key} from config")
return token
@@ -137,7 +137,12 @@ def encode(
list[int]: The list of token ids.
"""
token_ids = self.tokenizer.encode(text).ids
- if add_bos and not self.hf_adds_bos and self.bos_token not in text:
+ if (
+ add_bos
+ and not self.hf_adds_bos
+ and self.bos_token is not None
+ and self.bos_token not in text
+ ):
token_ids.insert(0, self.bos_id)
if add_eos and not self.hf_adds_eos:
token_ids.append(self.eos_id)
@@ -205,13 +210,13 @@ class HuggingFaceModelTokenizer(ModelTokenizer):
Then, it will load all special tokens and chat template from tokenizer config file.
It can be used to tokenize messages with correct chat template, and it eliminates the requirement of
- the specific ModelTokenizer and custom PromptTemplate.
+ the specific ModelTokenizer.
Args:
tokenizer_json_path (str): Path to tokenizer.json file
- tokenizer_config_json_path (Optional[str]): Path to tokenizer_config.json file. Default: None
- generation_config_path (Optional[str]): Path to generation_config.json file.
- Default: None
+ tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
+ generation_config_path (str | None): Path to generation_config.json file. Default: None
+ chat_template_path (str | None): Path to chat_template.jinja file. Default: None
truncation_type (str): type of truncation to apply, either "left" or "right".
Default is "right".
"""
@@ -220,8 +225,9 @@ def __init__(
self,
tokenizer_json_path: str,
*,
- tokenizer_config_json_path: Optional[str] = None,
- generation_config_path: Optional[str] = None,
+ tokenizer_config_json_path: str | None = None,
+ generation_config_path: str | None = None,
+ chat_template_path: str | None = None,
truncation_type: str = "right",
):
self.base_tokenizer = HuggingFaceBaseTokenizer(
@@ -240,7 +246,13 @@ def __init__(
# It is used sometimes in HF chat_templates
_env.globals["raise_exception"] = self._raise_helper
- self.template = _env.from_string(config["chat_template"])
+
+ if chat_template_path:
+ with open(chat_template_path, "r") as f:
+ self.template = _env.from_string(f.read())
+ else:
+ self.template = _env.from_string(config["chat_template"])
+
self.truncation_type = truncation_type
self.special_tokens_mapping = {}
@@ -262,8 +274,14 @@ def extract_top_level_variables(self, config):
def render_template(
self, messages: list[dict[str, str]], add_eos: bool = True
) -> str:
+ # Need to set tool_calls to something for qwen chat_template
+ if self.base_tokenizer.config["tokenizer_class"] == "Qwen2Tokenizer":
+ for message in messages:
+ if "tool_calls" not in message:
+ message["tool_calls"] = {}
rendered = self.template.render(
messages=messages,
+ tools=None,
add_generation_prompt=add_eos,
**self.special_tokens_mapping, # We assume that the naming is consistent
**self.top_level_variables,
@@ -274,7 +292,7 @@ def tokenize_messages(
self,
messages: list[Message],
add_eos: bool = True,
- max_seq_len: Optional[int] = None,
+ max_seq_len: int | None = None,
) -> tuple[list[int], list[bool]]:
tokenized_messages = []
mask = []
@@ -291,10 +309,13 @@ def tokenize_messages(
add_eos=add_eos if i == len(messages) - 1 else False,
)
- current_tokens = self.base_tokenizer.encode(rendered, add_eos=False)
+ current_tokens = self.base_tokenizer.encode(
+ rendered, add_bos=False, add_eos=False
+ )
if (
- self.base_tokenizer.bos_token in rendered
+ self.base_tokenizer.bos_token is not None
+ and self.base_tokenizer.bos_token in rendered
and self.base_tokenizer.hf_adds_bos
):
del current_tokens[0]
diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py
index e87ff50c9..e335c23e4 100644
--- a/src/forge/data/utils.py
+++ b/src/forge/data/utils.py
@@ -4,13 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+import logging
from enum import Enum
-from typing import Any, Literal, Optional, Union
+from typing import Any, Iterator, Literal, Union
import torch
+import torch.distributed as dist
from torch.nn.attention.flex_attention import BlockMask
+logger = logging.getLogger(__name__)
+
CROSS_ENTROPY_IGNORE_IDX = -100
Role = Literal[
@@ -32,7 +36,7 @@ class TuneMessage:
"""
This class represents individual messages in a fine-tuning dataset. It supports
text-only content, text with interleaved images, and tool calls. The
- :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize
+ :class:`~forge.interfaces.ModelTokenizer` will tokenize
the content of the message using ``tokenize_messages`` and attach the appropriate
special tokens based on the flags set in this class.
@@ -61,8 +65,7 @@ class TuneMessage:
- All ipython messages (tool call returns) should set ``eot=False``.
Note:
- TuneMessage class expects any image content to be a ``torch.Tensor``, as output
- by e.g. :func:`~torchtune.data.load_image`
+ TuneMessage class expects any image content to be a ``torch.Tensor``.
"""
def __init__(
@@ -118,7 +121,7 @@ def __repr__(self) -> str:
def truncate(
tokens: list[Any],
max_seq_len: int,
- eos_id: Optional[Any] = None,
+ eos_id: Any | None = None,
truncation_type: str = "right",
) -> list[Any]:
"""
@@ -128,7 +131,7 @@ def truncate(
Args:
tokens (list[Any]): list of tokens to truncate
max_seq_len (int): maximum length of the list
- eos_id (Optional[Any]): token to replace the last token with. If None, the
+ eos_id (Any | None): token to replace the last token with. If None, the
last token will not be replaced. Default is None.
truncation_type (str): type of truncation to apply, either "left" or "right".
Default is "right".
@@ -214,3 +217,118 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
f"Tensor, or BlockMask with flexattention enabled. "
f'Got key "{k}" with value of type {type(v)}'
)
+
+
+class StopAfterOneEpoch:
+ """Wraps an iterator, e.g. dataloader, and stops iterating after a rank shows that an epoch has been completed.
+
+ In distributed eval, we may have len(dataset) % num_ranks != 0. This means that some ranks may be on epoch 0
+ while others are already in epoch 1. To avoid hangs, all ranks *must* stop at the same time, requiring communication.
+
+ This function minimzes this impact by fetching one batch in advance and perfoming overlapping async all_reduce.
+
+ Assumes batch contains field "metrics" with at least one Metric containing "num_epochs" in its key, as it is done in
+ `forge.src.data.datasets.HfIterableDataset`.
+
+ Args:
+ iter (Iterator): Iterator over dataloader batches
+ device (torch.device): Device for synchronizing tensors
+ dp_mesh (dist.ProcessGroup | None): Data parallel process group (None for single process)
+ """
+
+ def __init__(
+ self,
+ iter: Iterator,
+ device: torch.device,
+ dp_mesh: dist.ProcessGroup | None = None,
+ ):
+ self.iter = iter
+ self.device = device
+ self.dp_mesh = dp_mesh
+
+ # Prefetch first batch for pipeline-style execution
+ self._next_batch = next(iter)
+
+ # Track pending async epoch sync
+ self._epoch_tensor: torch.Tensor | None = None
+ self._pending_work: Any = None
+ self._should_stop = False
+
+ def __iter__(self):
+ return self
+
+ def __next__(self) -> dict:
+ """Get next batch from current epoch.
+
+ Returns:
+ Batch dict guaranteed to be from current epoch
+
+ Raises:
+ StopIteration: When epoch completes across all ranks
+ """
+ # Check if previous epoch sync completed
+ if self._pending_work is not None:
+ self._pending_work.wait()
+ if self._epoch_tensor.item() > 0:
+ self._should_stop = True
+ self._pending_work = None
+ self._epoch_tensor = None
+
+ if self._should_stop:
+ logger.debug("Eval epoch completed. Stopping data iterator.")
+ raise StopIteration
+
+ # Get current batch
+ current_batch = self._next_batch
+ current_epoch = extract_epoch_from_batch(current_batch)
+
+ # Prefetch next batch and check for epoch change
+ self._next_batch = next(self.iter)
+ next_epoch = extract_epoch_from_batch(self._next_batch)
+ epoch_changed = next_epoch > current_epoch
+
+ # Start async epoch sync
+ if torch.distributed.is_initialized():
+ self._epoch_tensor = torch.tensor([int(epoch_changed)], device=self.device)
+ self._pending_work = torch.distributed.all_reduce(
+ self._epoch_tensor,
+ op=torch.distributed.ReduceOp.MAX,
+ group=self.dp_mesh,
+ async_op=True,
+ )
+ elif epoch_changed:
+ # if not distributed, just update the flag directly
+ self._should_stop = True
+
+ return current_batch
+
+
+def extract_epoch_from_batch(batch: dict) -> int:
+ """Extract epoch number from batch metrics. Useful to detect epoch changes during validation.
+
+ Assumes batch contains field "metrics" with at least one Metric containing "num_epochs" in its key, as it is done in
+ `forge.src.data.datasets.HfIterableDataset`.
+
+ Args:
+ batch (dict): Batch dictionary with 'metrics' field
+
+ Returns:
+ int: Max epoch number from metrics
+
+ Raises:
+ ValueError: If metrics key is missing or no metric with 'num_epochs' found
+ """
+ if "metrics" not in batch:
+ raise ValueError(
+ "Batch missing 'metrics' field. Cannot extract epoch from batch."
+ )
+
+ # Match metrics where 'num_epochs' appears in the key (handles prefixed keys like 'dataset/name/num_epochs')
+ epochs = [metric.value for metric in batch["metrics"] if "num_epochs" in metric.key]
+ if epochs:
+ return max(epochs)
+
+ raise ValueError(
+ f"No 'num_epochs' metric found in batch. Got metrics: "
+ f"{[m.key for m in batch['metrics']]}"
+ )
diff --git a/src/forge/data_models/completion.py b/src/forge/data_models/completion.py
index 00dae9022..123caee55 100644
--- a/src/forge/data_models/completion.py
+++ b/src/forge/data_models/completion.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
-from typing import Optional
+from typing import Any
import torch
@@ -29,10 +29,13 @@ class Completion:
token_ids: torch.Tensor
# the log probabilities of the target tokens
- logprobs: Optional[torch.Tensor] = None
+ logprobs: torch.Tensor | None = None
# the reason for stopping the generation
stop_reason: str | None = None
# the version identifier of the model when the generation was performed
generator_version: int | None = None
+
+ # extra information that might be useful for debugging
+ metadata: dict[str, Any] | None = None
diff --git a/src/forge/data_models/episode.py b/src/forge/data_models/episode.py
deleted file mode 100644
index 835373d18..000000000
--- a/src/forge/data_models/episode.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from dataclasses import dataclass
-from typing import Optional, Sequence
-
-import torch
-
-from forge.data_models.scored_completion import ScoredCompletion
-
-
-@dataclass
-class Episode:
- """
- The Episode data class to be used by the trainer.
-
- Episodes are usually generated from a scored completion and running various post processing steps.
- """
-
- # Concatenated prompt and sample token ids.
- ids: torch.Tensor
-
- # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
- mask: torch.Tensor
-
- # The weight to apply to the loss of each target token. It's normally computed
- # from the advantage and the reward.
- weights: torch.Tensor
-
- # The log probabilities of the target tokens, for prompt part it's set to 0,
- # for generation part it's computed from the Generator/Sampler.
- log_probs: Optional[torch.Tensor] = None
-
- # TODO: add more fields as required
- state: str = ""
-
-
-def from_scored_completion(scored_completion: ScoredCompletion) -> Episode:
- """Converts a ScoredCompletion to an Episode."""
- prompt_ids = scored_completion.completion.prompt_ids
- token_ids = scored_completion.completion.token_ids
- log_probs = scored_completion.completion.log_probs
- ids = torch.cat([prompt_ids, token_ids])
- mask = torch.cat(
- [
- torch.zeros(prompt_ids.shape, dtype=torch.float32),
- torch.ones_like(token_ids, dtype=torch.float32),
- ]
- )
- advantage = scored_completion.score
- weights = mask * advantage
- log_probs = torch.cat(
- [
- torch.zeros(prompt_ids.shape, dtype=torch.float32),
- # TODO: this only works if sample.log_probs is 1
- log_probs,
- ]
- )
- return Episode(ids=ids, mask=mask, weights=weights, log_probs=log_probs)
-
-
-def from_scored_completions(
- scored_completions: Sequence[ScoredCompletion],
-) -> Sequence[Episode]:
- """Converts a sequence of ScoredCompletion to a sequence of Episodes."""
- return [from_scored_completion(sc) for sc in scored_completions]
diff --git a/src/forge/data_models/scored_completion.py b/src/forge/data_models/scored_completion.py
deleted file mode 100644
index f41ff7b59..000000000
--- a/src/forge/data_models/scored_completion.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from dataclasses import dataclass
-
-from forge.data_models.completion import Completion
-
-
-@dataclass
-class ScoredCompletion:
- """A completion with an associated score (from a reward model or human)."""
-
- completion: Completion
- score: float # akin to reward
-
- # TODO: add more fields as needed.
diff --git a/src/forge/env.py b/src/forge/env.py
new file mode 100644
index 000000000..b698b8013
--- /dev/null
+++ b/src/forge/env.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Centralized constants for environment variable names used in the project."""
+
+import os
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass
+class EnvVar:
+ """Configuration for an environment variable."""
+
+ name: str
+ default: Any
+ description: str
+
+ def get_value(self) -> Any:
+ """Get the value of this environment variable with fallback to default.
+
+ Returns:
+ The environment variable value, auto-converted to the appropriate type
+ based on the default value, or the default value if not set.
+
+ Example:
+ >>> DISABLE_PERF_METRICS.get_value()
+ False
+ >>> os.environ["DISABLE_PERF_METRICS"] = "true"
+ >>> DISABLE_PERF_METRICS.get_value()
+ True
+ """
+ value = os.environ.get(self.name)
+
+ if value is None:
+ return self.default
+
+ # Auto-convert based on the default type
+ if isinstance(self.default, bool):
+ return value.lower() in ("true", "1", "yes")
+ elif isinstance(self.default, int):
+ return int(value)
+ elif isinstance(self.default, float):
+ return float(value)
+ else:
+ # Return as string for other types
+ return value
+
+
+# Environment variable definitions
+DISABLE_PERF_METRICS = EnvVar(
+ name="DISABLE_PERF_METRICS",
+ default=False,
+ description="Performance metrics in forge.observability.perf_tracker.py becomes no-op",
+)
+
+METRIC_TIMER_USES_GPU = EnvVar(
+ name="METRIC_TIMER_USES_GPU",
+ default=None,
+ description=(
+ "Force all timing methods in forge.observability.perf_tracker.py "
+ "to use CPU timer if False or GPU timer if True. If unset (None), defaults to the timer parameter."
+ ),
+)
+
+FORGE_DISABLE_METRICS = EnvVar(
+ name="FORGE_DISABLE_METRICS",
+ default=False,
+ description=(
+ "Makes forge.observability.metrics.record_metric a no-op and disables spawning LocalFetcherActor"
+ " in get_or_create_metric_logger"
+ ),
+)
+
+MONARCH_STDERR_LEVEL = EnvVar(
+ name="MONARCH_STDERR_LOG",
+ default="warning",
+ description="Sets Monarch's stderr log level, i.e. set to 'info' or 'debug'",
+)
+
+RUST_BACKTRACE = EnvVar(
+ name="RUST_BACKTRACE",
+ default="full",
+ description="Sets the level for Rust-level failures. I.e. set to full for full stack traces.",
+)
+
+MONARCH_MESSAGE_DELIVERY_TIMEOUT = EnvVar(
+ name="HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS",
+ default=600,
+ description="Sets the timeout limit for Monarch's actor message delivery in seconds.",
+)
+
+MONARCH_MAX_FRAME_LENGTH = EnvVar(
+ name="HYPERACTOR_CODE_MAX_FRAME_LENGTH",
+ default=1073741824,
+ description="Sets the maximum frame length for Monarch's actor message delivery in bytes.",
+)
+
+TORCHSTORE_USE_RDMA = EnvVar(
+ name="TORCHSTORE_RDMA_ENABLED",
+ default=1,
+ description="Whether or not to use RDMA in TorchStore.",
+)
+
+
+def all_env_vars() -> list[EnvVar]:
+ """Retrieves all registered environment variable names."""
+ env_vars = []
+ for _, value in globals().items():
+ if isinstance(value, EnvVar):
+ env_vars.append(value)
+ return env_vars
diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py
deleted file mode 100644
index 3adcdfc41..000000000
--- a/src/forge/env_constants.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""Centralized constants for environment variable names used in the project."""
-
-# Performance metrics in forge.observability.perf_tracker.py becomes no-op
-DISABLE_PERF_METRICS = "DISABLE_PERF_METRICS"
-
-# Force all timing methods in forge.observability.perf_tracker.py to use
-# CPU timer if False or GPU timer if True. If unset, defaults to the assigned value to the function.
-METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA"
-
-# Makes forge.observability.metrics.record_metric a no-op
-FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS"
diff --git a/src/forge/envs/chat.py b/src/forge/envs/chat.py
deleted file mode 100644
index 24a5981a6..000000000
--- a/src/forge/envs/chat.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from dataclasses import dataclass, field
-
-import torch
-
-from forge.interfaces import Environment, Message, ModelTokenizer, Transform
-
-from forge.types import Action, Observation, State
-
-
-@dataclass
-class ChatAction(Action):
- """Action for chat environments.
-
- Contains tokens that represent the action to be taken.
- This interfaces directly with models.
- """
-
- tokens: torch.Tensor = field(default_factory=lambda: torch.tensor([]))
-
- def __post_init__(self):
- """Validate required fields after initialization."""
- if self.tokens.numel() == 0:
- raise ValueError("tokens is required and cannot be empty")
-
-
-@dataclass
-class ChatState(State):
- """State of the ChatEnvironment containing message history."""
-
- history_messages: list[Message] = field(default_factory=list)
- history_tokens: list[torch.Tensor] = field(
- default_factory=list
- ) # Same len as messages
-
-
-@dataclass
-class ChatObservation(Observation):
- """Observation returned by ChatEnvironment.
-
- Contains the message history in Huggingface format (list of dicts with role/content)
- and the tokenized representation of the entire conversation.
-
- The environment owns the tokenizer and generates the tokens from the messages.
-
- Example:
- messages = [
- {"role": "system", "content": "You are a helpful assistant"},
- {"role": "user", "content": "How tall is the Eiffel Tower?"},
- ]
- tokens = tensor([1, 2, 3, 4, 5, ...]) # tokenized entire conversation
- """
-
- messages: list[Message] = field(default_factory=list)
- tokens: torch.Tensor = field(default_factory=lambda: torch.tensor([]))
- # Inherited fields from Observation ABC: reward, done, metadata
-
-
-class ChatEnvironment(Environment):
- """A chat-based environment for LLMs, designed as a blank canvas for conversation and RL.
-
- This environment is designed to work with language models. It provides the fundamental structure
- for managing conversation state but is intentionally minimal to allow maximum flexibility.
-
- The environment owns the tokenizer and is responsible for managing both message history and tokens.
- Actions contain only tokens that interface directly with models.
-
- Args:
- tokenizer: A tokenizer that will be used to tokenize the conversation
- system_prompt: An optional system prompt string to use during reset calls (optional)
- system_role: The role of the system (at reset time). Defaults to "system"
- """
-
- def __init__(
- self,
- tokenizer: ModelTokenizer,
- system_prompt: str | None = None,
- system_role: str = "system",
- transform: Transform | None = None,
- ):
- super().__init__(transform=transform)
-
- if not hasattr(tokenizer, "apply_chat_template"):
- raise ValueError("Tokenizer must have 'apply_chat_template' method")
- self.tokenizer = tokenizer
- self.system_prompt = system_prompt
- self.system_role = system_role
-
- self._state = ChatState()
-
- if system_prompt:
- system_message: Message = {"role": system_role, "content": system_prompt}
- self._state.history_messages.append(system_message)
- # Tokenize the system message
- system_tokens = self.tokenizer.apply_chat_template(
- conversation=[system_message], tokenize=True, return_tensors="pt" # type: ignore
- )
- self._state.history_tokens.append(system_tokens)
-
- def reset(self) -> ChatObservation:
- """Reset the environment to initial state.
-
- Returns:
- ChatObservation: Initial observation with system prompt (if any)
- """
- self._state.history_messages = []
- self._state.history_tokens = []
- if self.system_prompt:
- system_message: Message = {
- "role": self.system_role,
- "content": self.system_prompt,
- }
- self._state.history_messages = [system_message]
- # Tokenize the system message
- system_tokens = self.tokenizer.apply_chat_template(
- conversation=[system_message], tokenize=True, return_tensors="pt" # type: ignore
- )
- self._state.history_tokens = [system_tokens]
-
- return self._create_observation()
-
- def step(self, action: ChatAction) -> ChatObservation:
- """Take a step in the environment by adding tokens to the chat history.
-
- Args:
- action: A ChatAction object containing tokens.
-
- Returns:
- ChatObservation: The updated observation with the new tokens added.
- """
- # Store the tokens directly from the action
- self._state.history_tokens.append(action.tokens)
-
- # Decode tokens to text and add as a message to history
- decoded_text = self.tokenizer.decode(
- action.tokens.squeeze(), skip_special_tokens=True
- )
- assistant_message: Message = {"role": "assistant", "content": decoded_text}
- self._state.history_messages.append(assistant_message)
-
- return self._create_observation()
-
- def _create_observation(self) -> ChatObservation:
- """Create a ChatObservation from the current state.
-
- Returns both the message history and the tokens flattened as a single tensor
- ready to be used by models.
-
- Returns:
- ChatObservation: Observation with messages and flattened tokens
- """
- if self._state.history_tokens:
- flattened_tokens = torch.cat(self._state.history_tokens, dim=0)
- else:
- flattened_tokens = torch.tensor([])
-
- observation = ChatObservation(
- messages=self._state.history_messages.copy(), # Copy to prevent external mutation
- tokens=flattened_tokens,
- )
-
- transformed = self._apply_transform(observation)
- if isinstance(transformed, ChatObservation):
- return transformed
- else:
- # If transform returns base Observation, convert back to ChatObservation
- return ChatObservation(
- messages=getattr(transformed, "messages", []),
- tokens=getattr(transformed, "tokens", torch.tensor([])),
- done=transformed.done,
- reward=transformed.reward,
- )
-
- @property
- def state(self) -> ChatState:
- """Get the current state of the environment.
-
- Returns:
- ChatState: The current state.
- """
- return self._state
-
- def message_to_action(self, message: Message) -> ChatAction:
- """Convert a message dictionary to a ChatAction with tokens.
-
- Args:
- message: Dictionary with 'role' and 'content' keys
-
- Returns:
- ChatAction: A new ChatAction instance with tokenized content
-
- Raises:
- ValueError: If required keys are missing
- """
- if "role" not in message:
- raise ValueError("Message must contain a 'role' key")
- if "content" not in message:
- raise ValueError("Message must contain a 'content' key")
- if message["content"] is None:
- raise ValueError("Message content cannot be None")
-
- # Tokenize the single message
- tokens = self.tokenizer.apply_chat_template(
- conversation=[message], tokenize=True, return_tensors="pt" # type: ignore
- )
-
- return ChatAction(tokens=tokens)
diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py
index df79c302e..8a4ca06ef 100644
--- a/src/forge/interfaces.py
+++ b/src/forge/interfaces.py
@@ -7,98 +7,13 @@
from abc import ABC, abstractmethod
from typing import Any, Mapping
-from monarch.actor import endpoint
-
-from forge.controller import ForgeActor
-
-from forge.types import Action, Message, Observation, Scalar, State
-
-
-class Transform(ABC):
- """Abstract base class for observation transforms.
-
- Transforms are first-class citizens that can modify observations,
- typically to add rewards, compute metrics, or modify state.
-
- They follow a functional interface where they take an observation
- and return a (potentially modified) observation.
- """
-
- @abstractmethod
- def __call__(self, observation: Observation) -> Observation:
- """Transform an observation.
-
- Args:
- observation: The input observation to transform
-
- Returns:
- The transformed observation (may be the same instance if no changes)
- """
- pass
-
-
-class Environment(ABC):
- """Abstract base class for environments.
-
- Args:
- transform: Optional transform that modifies observations, typically to add rewards.
- Can be a Transform instance or a callable for backward compatibility.
- """
-
- def __init__(
- self,
- transform: Transform | None = None,
- ):
- self.transform = transform
-
- @abstractmethod
- def reset(self) -> Observation:
- """Reset the environment and return an initial observation."""
- pass
-
- @abstractmethod
- def step(self, action: Any) -> Observation:
- """Take a step in the environment and return an observation."""
- pass
-
- @property
- @abstractmethod
- def state(self) -> State:
- """Get the current state of the environment."""
- pass
-
- def _apply_transform(self, observation: Observation) -> Observation:
- """Apply the transform to an observation if one is provided."""
- if self.transform is not None:
- return self.transform(observation)
- return observation
-
-
-class Policy(ForgeActor, ABC):
- """Abstract interface for policies."""
-
- @endpoint
- @abstractmethod
- async def generate(self, request: Observation) -> Action:
- """Generate an action given a state/request."""
- pass
-
- @endpoint
- @abstractmethod
- async def update_weights(self, policy_version: int):
- """Update the policy weights.
-
- Args:
- policy_version: The version number to update to.
- """
- pass
+from forge.types import Message, Scalar
class BaseTokenizer(ABC):
"""
Abstract token encoding model that implements ``encode`` and ``decode`` methods.
- See :class:`~torchtune.modules.transforms.tokenizers.SentencePieceBaseTokenizer` and
- :class:`~torchtune.modules.transforms.tokenizers.TikTokenBaseTokenizer` for example implementations of this protocol.
+ See :class:`forge.data.HuggingFaceModelTokenizer for an example implementation of this protocol.
"""
@abstractmethod
@@ -133,7 +48,7 @@ def decode(self, token_ids: list[int], **kwargs: dict[str, Any]) -> str:
class ModelTokenizer(ABC):
"""
Abstract tokenizer that implements model-specific special token logic in
- the ``tokenize_messages`` method. See :class:`~torchtune.models.llama3.Llama3Tokenizer`
+ the ``tokenize_messages`` method. See :class:`forge.data.HuggingFaceModelTokenizer`
for an example implementation of this protocol.
"""
@@ -201,19 +116,3 @@ def close(self) -> None:
This will automatically be called via __del__ when the instance goes out of scope.
Logs should not be written after `close` is called.
"""
-
-
-class Reward(ABC):
- """Abstract base class for reward models."""
-
- @abstractmethod
- def __call__(self, observation: Observation) -> float:
- """Compute a reward for an observation."""
- pass
-
-
-# TODO
-# class RLLoss(ABC):
-
-# class SFTLoss(ABC): # inherit from titan loss
-# from torchtitan.components.loss import LossFunction
diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index c2dedd530..f1c92b667 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -5,9 +5,9 @@
# LICENSE file in the root directory of this source tree.
import torch
-from torch import nn
-from forge.util.ops import selective_log_softmax
+from forge.util.ops import compute_logprobs
+from torch import nn
class ReinforceLoss(nn.Module):
@@ -29,7 +29,7 @@ def __init__(self):
def forward(
self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
):
- trainer_log_probs = selective_log_softmax(trainer_logits, target_ids)
+ trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
target_mask = target_mask.detach()
target_weights = target_weights
target_mask_sum = target_mask.sum()
diff --git a/src/forge/observability/README.md b/src/forge/observability/README.md
new file mode 100644
index 000000000..6a653dbbd
--- /dev/null
+++ b/src/forge/observability/README.md
@@ -0,0 +1,294 @@
+# Metric Logging in Forge
+
+We aim to make distributed observability effortless. You can call `record_metric(key, val, reduce_type)` from anywhere, and it just works. We also provide memory/performance tracers, plug-and-play logging backends, and reduction types. You can visualize aggregated results globally, per-rank or as a stream. No boilerplate required - just call, flush, and visualize. Disable with `FORGE_DISABLE_METRICS=true`.
+
+## 1. Your Superpowers
+
+### 1.1 Call `record_metric` from Anywhere
+
+Simple to use, with no need to pass dictionaries around. For example, users can simply write:
+
+```python
+def my_fn():
+ record_metric(key, value, reduce)
+```
+
+Instead of:
+
+```python
+def my_fn(my_metrics):
+ my_metrics[key] = value
+ return my_metrics
+```
+
+Simple example (for a distributed one, check the next section)
+```python
+import asyncio
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce
+
+async def main():
+ # Setup logger
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
+ await mlogger.init_backends.call_one({"console": {"logging_mode": "global_reduce"}})
+
+ # Have this in any process
+ def my_fn(number):
+ record_metric("my_sum_metric", number, Reduce.SUM) # sum(1,2,3)
+ record_metric("my_max_metric", number, Reduce.MAX) # max(1,2,3)
+ record_metric("my_mean_metric", number, Reduce.MEAN) # mean(1,2,3)
+
+ # Accumulate metrics
+ for number in range(1, 4): # 1, 2, 3
+ my_fn(number)
+
+ # Flush
+ await mlogger.flush.call_one(global_step=0)
+
+ # Shutdown when done
+ await mlogger.shutdown.call_one()
+
+if __name__ == "__main__":
+ asyncio.run(main())
+```
+
+Output:
+```bash
+=== [GlobalReduce] - METRICS STEP 0 ===
+my_sum_metric: 6.0
+my_max_metric: 3.0
+my_mean_metric: 2.0
+```
+
+### 1.2 Track Performance: Timing and Memory
+
+Use `Tracer` for tracking durations and memory usage. Overhead is minimal, and GPU timing is non-blocking. Set `timer="gpu"` for kernel-level precision. Tracer leverages `record_metric` in the backend.
+
+```python
+from forge.observability.perf_tracker import Tracer
+import torch
+
+# ... Initialize logger (as shown in previous example)
+
+def my_fn():
+ a = torch.randn(1000, 1000, device="cuda")
+
+ t = Tracer(prefix="my_cuda_loop", track_memory=True, timer="gpu")
+ t.start()
+ for _ in range(3):
+ torch.mm(a, a)
+ t.step("my_metric_mm")
+ t.stop()
+
+# Accumulate metrics
+for _ in range(2):
+ my_fn()
+
+await mlogger.flush(global_step=0) # Flush and reset
+```
+
+Output:
+```bash
+=== [GlobalReduce] - METRICS STEP 0 ===
+my_cuda_loop/memory_delta_end_start_avg_gb: 0.015
+my_cuda_loop/memory_peak_max_gb: 0.042
+my_cuda_loop/my_metric_mm/duration_avg_s: 0.031
+my_cuda_loop/my_metric_mm/duration_max_s: 0.186
+my_cuda_loop/total_duration_avg_s: 0.094
+my_cuda_loop/total_duration_max_s: 0.187
+```
+
+For convenience, you can also use `Tracer` as a context manager or decorator:
+
+```python
+from forge.observability.perf_tracker import trace
+
+with trace(prefix="train_step", track_memory=True, timer="gpu") as t:
+ t.step("fwd")
+ loss = model(x)
+ t.step("bwd")
+ loss.backward()
+
+@trace(prefix="my_reward_fn", track_memory=False, timer="cpu")
+async def reward_fn(x): # Supports both sync/async functions
+ return 1.0 if x > 0 else 0.0
+```
+## 2. Logging Modes
+
+Defined per backend. You have three options:
+
+- **global_reduce**: N ranks = 1 chart. Reduces metrics across all ranks. Ideal for a single aggregated view (e.g., average loss chart).
+- **per_rank_reduce**: N ranks = N charts. Each rank reduces locally and logs to its own logger. Ideal for per-rank performance debugging (e.g., GPU utilization).
+- **per_rank_no_reduce**: N ranks = N charts. Each rank streams to its own logger without reduction. Ideal for real-time streams.
+
+Consider an example with an actor running on 2 replicas, each with 2 processes, for a total of 4 ranks. We will record the sum of the rank values. For example, rank_0 records 0, and rank_1 records 1.
+
+```python
+import asyncio
+
+from forge.controller.actor import ForgeActor
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce
+from monarch.actor import current_rank, endpoint
+
+# Your distributed actor
+class MyActor(ForgeActor):
+ @endpoint
+ async def my_fn(self):
+ rank = current_rank().rank # 0 or 1 per replica
+ record_metric("my_sum_rank_metric", rank, Reduce.SUM) # <--- your metric
+
+async def main():
+ # Setup logger
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
+ await mlogger.init_backends.call_one(
+ {"console": {"logging_mode": "global_reduce"}} # <--- Define logging_mode here
+ )
+
+ # Setup actor
+ service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
+ my_actor = await MyActor.options(**service_config).as_service()
+
+ # Accumulate metrics
+ for _ in range(2): # 2 steps
+ await my_actor.my_fn.fanout()
+
+ # Flush
+ await mlogger.flush.call_one(global_step=0) # Flush and reset
+
+if __name__ == "__main__":
+ asyncio.run(main())
+```
+
+Output when `"logging_mode": "global_reduce"`
+```bash
+=== [GlobalReduce] - METRICS STEP 0 ===
+my_sum_rank_metric: 4.0 # (0 + 1) * 2 steps * 2 replicas
+===============
+```
+
+Now, let’s set `"logging_mode": "per_rank_reduce"`:
+```bash
+# replica 1
+=== [MyActor_661W_r0] - METRICS STEP 0 ===
+my_sum_rank_metric: 0.0 # (rank_0) * 2 steps
+===============
+=== [MyActor_661W_r1] - METRICS STEP 0 ===
+my_sum_rank_metric: 2.0 # (rank_1) * 2 steps
+===============
+
+# replica 2
+=== [MyActor_wQ1g_r0] - METRICS STEP 0 ===
+my_sum_rank_metric: 0.0 # (rank_0) * 2 steps
+===============
+=== [MyActor_wQ1g_r1] - METRICS STEP 0 ===
+my_sum_rank_metric: 2.0 # (rank_1) * 2 steps
+===============
+```
+
+Finally, with `"logging_mode": "per_rank_no_reduce"`, we have a stream with no reduction:
+```bash
+[0] [MyActor-0/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 0
+[0] [MyActor-0/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 0
+[1] [MyActor-1/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 1
+[1] [MyActor-1/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 1
+[0] [MyActor-0/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 0
+[0] [MyActor-0/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 0
+[1] [MyActor-1/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 1
+[1] [MyActor-1/2] 2025-10-10 12:21:09 INFO my_sum_rank_metric: 1
+```
+
+## 3. Using Multiple Backends
+
+For example, you can do `global_reduce` with Weights & Biases while using `per_rank_no_reduce` for debugging logs on the console.
+
+```python
+mlogger = await get_or_create_metric_logger(process_name="Controller")
+await mlogger.init_backends.call_one({
+ "console": {"logging_mode": "per_rank_no_reduce"},
+ "wandb": {"logging_mode": "global_reduce"}
+})
+```
+
+### 3.1 Adding a New Backend
+
+Extend `LoggerBackend` for custom logging, such as saving data to JSONL files, sending Slack notifications when a metric hits a threshold, or supporting tools like MLFlow or Grafana. After writing your backend, register it with `forge.observability.metrics.get_logger_backend_class`.
+
+```python
+# Example of a custom backend
+class ConsoleBackend(LoggerBackend):
+ def __init__(self, logger_backend_config: dict[str, Any]) -> None:
+ super().__init__(logger_backend_config)
+
+ async def init(self, process_name: str | None = None, *args, **kwargs) -> None:
+ self.process_name = process_name
+
+ async def log_batch(self, metrics: list[Metric], global_step: int, *args, **kwargs) -> None:
+ # Called on flush
+ print(self.process_name, metrics)
+
+ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
+ # Called on `record_metric` if "logging_mode": "per_rank_no_reduce"
+ print(metric)
+```
+
+## 4. Adding a New Reduce Type
+
+Metrics are accumulated each time `record_metric` is called. The following example implements the `Reduce.MEAN` accumulator. Users can extend this by adding custom reduce types, such as `WordCounterAccumulator` or `SampleAccumulator`, and registering them with `forge.observability.metrics.Reduce`. For details on how this is used, see `forge.observability.metrics.MetricCollector`.
+
+
+```python
+# Example of a custom reduce type
+class MeanAccumulator(MetricAccumulator):
+ def __init__(self, reduction: Reduce) -> None:
+ super().__init__(reduction)
+ self.sum = 0.0
+ self.count = 0
+ self.is_reset = True
+
+ def append(self, value: Any) -> None:
+ # Called after record_metric(key, value, reduce.TYPE)
+ v = float(value.item() if hasattr(value, "item") else value)
+ self.sum += v
+ self.count += 1
+
+ def get_value(self) -> float:
+ return self.sum / self.count if self.count > 0 else 0.0
+
+ def get_state(self) -> dict[str, Any]:
+ return {"reduction_type": self.reduction_type.value, "sum": self.sum, "count": self.count}
+
+ @classmethod
+ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float:
+ # Useful for global reduce; called before flush
+ total_sum = sum(s["sum"] for s in states)
+ total_count = sum(s["count"] for s in states)
+ return total_sum / total_count if total_count > 0 else 0.0
+
+ def reset(self) -> None:
+ self.sum = 0.0
+ self.count = 0
+ self.is_reset = True
+```
+
+## 5. Behind the Scenes
+
+We have two main requirements:
+1. Metrics must be accumulated somewhere.
+2. Metrics must be collected from all ranks.
+
+To address #1, we use a `MetricCollector` per process to store state. For example, with 10 ranks, there are 10 `MetricCollector` instances. Within each rank, `MetricCollector` is a singleton, ensuring the same object is returned after the first call. This eliminates the need to pass dictionaries between functions.
+
+To address #2, we automatically spawn a `LocalFetcherActor` for each process mesh and register it with the `GlobalLoggingActor`. This allows the `GlobalLoggingActor` to know which processes to call, and each `LocalFetcherActor` can access the local `MetricCollector`. This spawning and registration occurs in `forge.controller.provisioner.py::get_proc_mesh`.
+
+The flow is generally:
+GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
+
+So you may ask: "what about the logging backends"? They live in two places:
+- In each MetricCollector if the backend is marked as per_rank.
+- In the GlobalLoggingActor if the backend is marked as global_reduce.
+
+In summary:
+1. One `GlobalLoggingActor` serves as the controller.
+2. For each process, `forge.controller.provisioner.py::get_proc_mesh` spawns a `LocalFetcherActor`, so N ranks = N `LocalFetcherActor` instances. These are registered with the `GlobalLoggingActor`.
+3. Each rank has a singleton `MetricCollector`, holding accumulated metrics and per_rank backends.
+4. Calling `record_metric(key, value, reduce_type)` stores metrics locally in the `MetricCollector`.
+5. When GlobalLoggingActor.flush() -> all LocalFetcherActor.flush() --> MetricCollector.flush()
diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py
index 52262eed5..988673e3c 100644
--- a/src/forge/observability/__init__.py
+++ b/src/forge/observability/__init__.py
@@ -10,39 +10,45 @@
LocalFetcherActor,
)
from .metrics import (
+ BackendRole,
ConsoleBackend,
- # Utility functions
- get_actor_name_with_rank,
get_logger_backend_class,
- # Backend classes
LoggerBackend,
+ LoggingMode,
MaxAccumulator,
MeanAccumulator,
- # Accumulator classes
+ Metric,
MetricAccumulator,
MetricCollector,
MinAccumulator,
record_metric,
Reduce,
reduce_metrics_states,
+ SampleAccumulator,
StdAccumulator,
SumAccumulator,
WandbBackend,
)
from .perf_tracker import trace, Tracer
+from .utils import get_proc_name_with_rank
__all__ = [
# Main API functions
"record_metric",
"reduce_metrics_states",
- "get_actor_name_with_rank",
"get_logger_backend_class",
"get_or_create_metric_logger",
# Performance tracking
"Tracer",
"trace",
+ # Data classes
+ "Metric",
+ "BackendRole",
# Enums
"Reduce",
+ "LoggingMode",
+ # Utility functions
+ "get_proc_name_with_rank",
# Actor classes
"GlobalLoggingActor",
"LocalFetcherActor",
@@ -59,4 +65,5 @@
"MaxAccumulator",
"MinAccumulator",
"StdAccumulator",
+ "SampleAccumulator",
]
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index d67a66a83..f7b1ccedf 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -6,17 +6,31 @@
import asyncio
import logging
-from typing import Any, Dict, Optional
+import uuid
+from typing import Any, Union
-from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
+from forge.controller.actor import ForgeActor
+from forge.env import FORGE_DISABLE_METRICS
from forge.observability.metrics import (
+ BackendRole,
get_logger_backend_class,
LoggerBackend,
+ LoggingMode,
MetricCollector,
+ Reduce,
reduce_metrics_states,
)
+from monarch.actor import (
+ context,
+ endpoint,
+ get_or_spawn_controller,
+ ProcMesh,
+ this_proc,
+)
+
+
logger = logging.getLogger(__name__)
_global_logger = None
@@ -24,54 +38,55 @@
async def get_or_create_metric_logger(
proc_mesh: ProcMesh | None = None,
+ process_name: str | None = None,
) -> "GlobalLoggingActor":
- """Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
- if not already initialized, registers it with the GlobalLoggingActor and returns the
- GlobalLoggingActor instance.
+ """Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized),
+ registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor.
- There are primarily two ways to use this function:
- 1. In the main process, call `get_or_create_metric_logger()` to get the global logger.
- 2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the
- local fetcher with the global logger.
+ Usage:
+ 1. Main process: call `get_or_create_metric_logger()` to get the global logger
+ 2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name)` to register the
+ map(proc_mesh,local fetcher) with the global logger, so it knows to broadcast to all ranks.
Args:
- proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
- uses `monarch.actor.this_proc()`.
+ proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`.
+ process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None.
Returns:
GlobalLoggingActor: The global logging controller.
Raises:
- ValueError: If the logging state is inconsistent, i.e. the fetcher is already
- registered, but only in the process or the global logger.
+ ValueError: If the logging state is inconsistent.
Example:
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric
# Main process setup
- mlogger = await get_or_create_metric_logger()
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
# Initialize logging backends
await mlogger.init_backends({
- "console": {"reduce_across_ranks": True},
- "wandb": {"project": "my_project", "reduce_across_ranks": False}
+ "console": {"logging_mode": "global_reduce"},
+ "wandb": {"project": "my_project", "logging_mode": "per_rank_reduce"}
})
# Initialize services...
- policy = await Policy.as_service(...)
+ policy = await Generator.as_service(...)
# Training loop
for step in range(max_steps):
- record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
+ record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
# ... training code with record_metric() calls ...
- await mlogger.flush(step) # Log metrics for this step
+ await mlogger.flush.call_one(step) # Log metrics for this step
# Shutdown
- await mlogger.shutdown()
+ await mlogger.shutdown.call_one()
"""
+
# Get or create the singleton global logger
global _global_logger
+
if _global_logger is None:
_global_logger = await get_or_spawn_controller(
"global_logger", GlobalLoggingActor
@@ -81,9 +96,15 @@ async def get_or_create_metric_logger(
# Determine process context
proc = proc_mesh if proc_mesh is not None else this_proc()
+ # Auto-detect process_name from proc mesh if not provided
+ if process_name is None:
+ ctx = context()
+ process_name = ctx.actor_instance.actor_id.actor_name
+
# Check current state for consistency
proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
- global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
+ proc_id = proc._uid if proc_has_local_fetcher else None
+ global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc_id)
# Consistency check: both should be in sync
if proc_has_local_fetcher != global_logger_has_local_fetcher:
@@ -95,173 +116,276 @@ async def get_or_create_metric_logger(
f"Both should be True (already setup) or both False (needs setup)."
)
- # Setup local_fetcher_actor if needed
- if not proc_has_local_fetcher:
+ # Setup local_fetcher_actor if needed (unless disabled by environment flag)
+ if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
local_fetcher_actor = proc.spawn(
- "local_fetcher_actor", LocalFetcherActor, global_logger
+ "local_fetcher_actor", LocalFetcherActor, global_logger, process_name
)
- await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
- proc._local_fetcher = local_fetcher_actor
+ # Generate a unique ID to map procmesh to fetcher
+ proc._uid = str(uuid.uuid4())
+ proc._local_fetcher = local_fetcher_actor # pyre-ignore
+
+ await global_logger.register_fetcher.call_one(local_fetcher_actor, proc._uid)
return global_logger
-class LocalFetcherActor(Actor):
- """Thin per-process actor used to trigger MetricCollector singleton
- operations without direct access. It is what GlobalLoggingActor
- uses to broadcast inits/flushes across ranks.
+class LocalFetcherActor(ForgeActor):
+ """Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh
+ and accesses each rank's local MetricCollector.
- GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
+ Flow:
+ GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
"""
- def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None:
+ def __init__(
+ self,
+ global_logger: Union["GlobalLoggingActor", None] = None,
+ process_name: str | None = None,
+ ) -> None:
self.global_logger = global_logger
- _is_initialized = False
+ self.process_name = process_name
@endpoint
async def flush(
- self, step: int, return_state: bool = False
- ) -> Dict[str, Dict[str, Any]]:
+ self, global_step: int, return_state: bool = False
+ ) -> dict[str, dict[str, Any]]:
"""Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True.
This should only ever be called by the global logger.
Args:
- step (int): train step used by backends to align all metrics on the same x-axis
+ global_step (int): step used by backends to align all metrics on the same x-axis
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
If False, returns empty dict, else returns the state of all metrics collected.
Returns:
- Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state},
+ dict[str, dict[str, Any]]: of {metric_key: metric_state},
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""
collector = MetricCollector()
- result = await collector.flush(step, return_state=return_state)
+ result = await collector.flush(global_step, return_state=return_state)
return result
@endpoint
async def init_backends(
self,
- metadata_per_primary_backend: Dict[str, Dict[str, Any]],
- config: Dict[str, Any],
- ):
- """Init local (per-rank) logger backends and MetricCollector."""
+ metadata_per_controller_backend: dict[str, dict[str, Any]],
+ backend_config: dict[str, Any],
+ run_config: dict[str, Any] | None = None,
+ global_step: int = 0,
+ ) -> None:
+ """Init per-rank logger backends and MetricCollector.
+
+ Args:
+ metadata_per_controller_backend (dict[str, dict[str, Any]]): Metadata from controller backends for shared state.
+ backend_config (dict[str, Any]): Backend configurations with logging modes and settings.
+ run_config (dict[str, Any] | None): Your application's configuration
+ (hyperparameters, dataset, model settings) to log to backends for
+ experiment tracking.
+ global_step (int): Initial step for metrics.
+ """
collector = MetricCollector()
- await collector.init_backends(metadata_per_primary_backend, config)
+ await collector.init_backends(
+ metadata_per_controller_backend,
+ backend_config,
+ global_step,
+ process_name=self.process_name,
+ run_config=run_config,
+ )
@endpoint
- async def shutdown(self):
-
+ async def shutdown(self) -> None:
collector = MetricCollector()
await collector.shutdown()
-class GlobalLoggingActor(Actor):
- """Coordinates metric logging across all ranks for every training step.
+class GlobalLoggingActor(ForgeActor):
+ """Coordinates metric logging across all ProcMeshes and their ranks.
Supports multiple logging backends (e.g., WandB, TensorBoard, etc.),
- for per-rank and/or global reduction logging modes.
-
- If a backend config has flag `reduce_across_ranks=False`, an instance of the backend
- is initialized per-rank, otherwise it is done once globally.
+ with per-rank and/or global reduction logging modes.
This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor
- is automatically spawned per-rank in `forge.controller.provisioner.py` and registered
+ is automatically spawned per-procmesh in `forge.controller.provisioner.py` and registered
with this actor. The LocalFetcherActor is responsible for instantiating
- the per-rank MetricCollector.
+ the per-rank MetricCollector and working as a bridge between GlobalLoggingActor and processes.
- In summary, the flow is:
- - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector
- - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush
+ Flow:
+ GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
"""
def __init__(self):
- self.fetchers: Dict[str, LocalFetcherActor] = {}
- self.config: Dict[str, Any] | None = None
- self.global_logger_backends: Dict[str, LoggerBackend] = {}
- self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {}
+ self.fetchers: dict[str, LocalFetcherActor] = {}
+ self.config: dict[str, Any] | None = None
+ self.run_config: dict[str, Any] | None = None
+ self.global_logger_backends: dict[str, LoggerBackend] = {}
+ self.metadata_per_controller_backend: dict[str, dict[str, Any]] = {}
+
+ def _validate_backend_config(
+ self, backend_name: str, config: dict[str, Any]
+ ) -> dict[str, Any]:
+ """Validate and normalize backend configuration."""
+ if "logging_mode" not in config:
+ raise ValueError(
+ f"logging_mode is required for backend '{backend_name}' but was not provided. "
+ f"Please specify a logging_mode in your config. "
+ f"See forge.observability.metrics.LoggingMode for available options: "
+ f"{', '.join([mode.value for mode in LoggingMode])}."
+ )
- @endpoint
- async def init_backends(self, config: Dict[str, Any]):
- """
- Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors
- in all registered fetchers.
+ # Convert string to LoggingMode enum
+ mode_value = config["logging_mode"]
+ if isinstance(mode_value, str):
+ mode = LoggingMode(mode_value)
+ elif isinstance(mode_value, LoggingMode):
+ mode = mode_value
+ else:
+ raise TypeError(
+ f"logging_mode must be str or LoggingMode enum, got {type(mode_value)}"
+ )
+
+ # Validate per_rank_share_run configuration
+ share_run = config.get("per_rank_share_run", False)
+ if mode == LoggingMode.GLOBAL_REDUCE and share_run:
+ logger.warning(
+ f"{backend_name}: per_rank_share_run=True is ignored in {mode.value} mode. "
+ "Setting it to False."
+ )
+ share_run = False
+
+ # WandB-specific warning for suboptimal configuration
+ if (
+ backend_name == "wandb"
+ and mode == LoggingMode.PER_RANK_REDUCE
+ and share_run
+ ):
+ logger.warning(
+ "WandB: Using 'per_rank_reduce' with 'per_rank_share_run=True' is not recommended. "
+ "This configuration can lead to confusing metrics where reduced values from multiple ranks "
+ "are written to the same run/step, displaying only one of them. Consider either:\n"
+ " 1. Set 'per_rank_share_run=False' to create separate runs per rank, OR\n"
+ " 2. Use 'per_rank_no_reduce' for real-time streaming to a shared run"
+ )
- A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source
- for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb.
+ return {
+ **config,
+ "logging_mode": mode,
+ "per_rank_share_run": share_run,
+ }
- The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False,
- a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend,
- and will log independently, i.e. each rank will have its own run in wandb.
+ @endpoint
+ async def init_backends(
+ self, backend_config: dict[str, Any], run_config: dict[str, Any] | None = None
+ ) -> None:
+ """Sets config in global actor and initializes existing backends and collectors. Later spawned actors
+ are initialized in `register_fetcher` endpoint.
- Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states
- and reduce them to a single value, which will be logged by the primary backend in this controller.
+ Controller backends (instantiated in the controller) can provide metadata to be shared with rank backends,
+ e.g. shared run IDs for WandB. For details on logging modes, see `forge.observability.metrics.LoggingMode`.
Args:
- config (Dict[str, Any]): Config for metric logging where keys are backend names,
- e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}}
+ backend_config (dict[str, Any]): Config for metric logging where keys are backend names.
+ Each backend config supports:
+ - logging_mode (str | LoggingMode): Check LoggingMode for options. Defaults to "global_reduce".
+ - per_rank_share_run (bool, default False): For per-rank modes only. Whether ranks
+ share a single run/logger instance. Ignored for "global_reduce" mode.
+ - Additional backend-specific options (e.g., "project" for WandB)
+
+ Example:
+ {
+ "console": {"logging_mode": "global_reduce"},
+ "wandb": {
+ "logging_mode": "per_rank_no_reduce",
+ "per_rank_share_run": True,
+ "project": "my_project",
+ }
+ }
+ run_config (dict[str, Any] | None): Your application's configuration
+ (hyperparameters, dataset, model settings) to log to backends for
+ experiment tracking.
+
+ Raises:
+ ValueError: If backend config is invalid or missing required fields.
"""
- self.config = config
+ self.config = {}
+ self.run_config = run_config
+
+ # Skip initialization if disabled by environment flag
+ if FORGE_DISABLE_METRICS.get_value():
+ return
+
+ # Validate and normalize each backend config
+ for backend_name, cfg in backend_config.items():
+ self.config[backend_name] = self._validate_backend_config(backend_name, cfg)
+
+ # Initialize backends based on logging mode
+ for backend_name, backend_config in self.config.items():
+ mode = backend_config["logging_mode"]
- for backend_name, backend_config in config.items():
- backend = get_logger_backend_class(backend_name)(backend_config)
- await backend.init(role="global")
+ backend: LoggerBackend = get_logger_backend_class(backend_name)(
+ **backend_config
+ )
+ await backend.init(
+ role=BackendRole.GLOBAL,
+ process_name="global_reduce",
+ run_config=self.run_config,
+ )
- # Extract metadata from primary logger to be shared with secondary loggers
- # and store it
- reduce_across_ranks = backend_config.get("reduce_across_ranks", True)
- if not reduce_across_ranks:
- primary_backend_metadata = (
+ # Extract metadata from controller logger to be shared with per-rank loggers
+ if mode != LoggingMode.GLOBAL_REDUCE:
+ controller_metadata: dict[str, Any] = (
backend.get_metadata_for_secondary_ranks() or {}
)
- self.metadata_per_primary_backend[
- backend_name
- ] = primary_backend_metadata
+ self.metadata_per_controller_backend[backend_name] = controller_metadata
- # Store global logger backends
- if reduce_across_ranks:
+ # Store global logger backends for later flush
+ if mode == LoggingMode.GLOBAL_REDUCE:
self.global_logger_backends[backend_name] = backend
- # Eager init collectors on all registered fetchers in parallel, passing primary states and config
+ # Init collectors on all registered fetchers
if self.fetchers:
tasks = [
fetcher.init_backends.call(
- self.metadata_per_primary_backend, self.config
+ self.metadata_per_controller_backend, self.config, self.run_config
)
for fetcher in self.fetchers.values()
]
await asyncio.gather(*tasks, return_exceptions=True)
@endpoint
- async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh):
- """Registers a fetcher with the global actor. Each key represents a process mesh.
- If there are 2 processes, each with 2 replicas with N gpus, we would
- have 4 keys, i.e. 2 proces meshes, each with 2 replicas."""
- self.fetchers[name] = fetcher # pyre-ignore
+ async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> None:
+ """Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh.
+
+ Args:
+ fetcher: The LocalFetcherActor instance for a ProcMesh
+ proc_id: Unique identifier for the ProcMesh
+ """
+ self.fetchers[proc_id] = fetcher
# Self-init for respawned actors
if self.config:
- logger.debug(f"Initializing new LocalFetcherActor {name}")
+ logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}")
await fetcher.init_backends.call(
- self.metadata_per_primary_backend, self.config
+ self.metadata_per_controller_backend, self.config, self.run_config
)
@endpoint
- async def deregister_fetcher(self, name: str | ProcMesh):
- if name not in self.fetchers:
+ async def deregister_fetcher(self, proc_id: str) -> None:
+ if proc_id not in self.fetchers:
logger.warning(
- f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister."
+ f"Fetcher {proc_id} not registered in GlobalLoggingActor. Cannot deregister."
f"Available fetchers: {self.fetchers.keys()}"
)
return
- del self.fetchers[name]
+ del self.fetchers[proc_id]
@endpoint
- async def flush(self, step: int):
+ async def flush(self, global_step: int) -> None:
"""
Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors
log to local backends and return states if needed for cross-rank reduction.
Args:
- step (int): Global step for logging.
+ global_step (int): step for logging.
"""
if not self.fetchers:
return
@@ -269,74 +393,97 @@ async def flush(self, step: int):
config = self.config
if config is None:
logger.warning(
- "GlobalLoggingActor flush() called before init_backends(). "
- "No backends will be flushed."
+ "Cannot flush collected metrics. GlobalLoggingActor.flush() called before init_backends()."
+ " No backends will be flushed. Please call in your main file:\n"
+ "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n"
+ "`await mlogger.init_backends.call_one(logging_config)`\n"
)
return
- # if reduce_across_ranks=True, we need to reduce the states from all ranks
- # and log with the primary backend
- requires_reduce = any(
- backend_config.get("reduce_across_ranks", True)
+
+ # Check if need to collect states from fetchers for global reduction
+ needs_state_collection = any(
+ backend_config["logging_mode"] == LoggingMode.GLOBAL_REDUCE
for backend_config in config.values()
)
- logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers")
+ logger.debug(
+ f"Global flush for global step {global_step}: {len(self.fetchers)} fetchers"
+ )
# Broadcast flush to all fetchers
results = await asyncio.gather(
*[
- f.flush.call(step, return_state=requires_reduce)
+ f.flush.call(global_step, return_state=needs_state_collection)
for f in self.fetchers.values()
],
return_exceptions=True,
)
- if requires_reduce:
- # Handle exceptions and extract values from ValueMesh results
- all_local_states = []
- for result in results:
- if isinstance(result, BaseException):
- logger.warning(f"Flush failed on a fetcher: {result}")
- continue
-
- # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}]
- for gpu_info, local_metric_state in result.items():
- if isinstance(local_metric_state, dict):
- all_local_states.append(local_metric_state)
- else:
- logger.warning(
- f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}"
- )
+ if needs_state_collection:
+
+ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
+ all_local_states = []
+ for result in results:
+ if isinstance(result, BaseException):
+ logger.warning(f"Flush failed on a fetcher: {result}")
+ continue
+
+ # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}]
+ for gpu_info, local_metric_state in result.items():
+ if isinstance(local_metric_state, dict):
+ all_local_states.append(local_metric_state)
+ else:
+ logger.warning(
+ f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}"
+ )
+ return all_local_states
+
+ all_local_states = extract_values_from_valuemesh(results)
if not all_local_states:
- logger.warning(f"No states to reduce for step {step}")
+ logger.warning(f"No states to reduce for global_step {global_step}")
return
- # Reduce
+ # Reduce metrics from states
reduced_metrics = reduce_metrics_states(all_local_states)
- # Log to each global logger_backend
- for (
- logger_backend_name,
- logger_backend,
- ) in self.global_logger_backends.items():
- await logger_backend.log(reduced_metrics, step)
+ # Split into scalar metrics and sample metrics
+ scalar_metrics = [
+ m for m in reduced_metrics if m.reduction != Reduce.SAMPLE
+ ]
+ sample_metrics = [
+ m for m in reduced_metrics if m.reduction == Reduce.SAMPLE
+ ]
+
+ # Log to global backends
+ for backend_name, backend in self.global_logger_backends.items():
+ if scalar_metrics:
+ await backend.log_batch(scalar_metrics, global_step)
+ if sample_metrics:
+ await backend.log_samples(sample_metrics, global_step)
@endpoint
- def has_fetcher(self, name: str | ProcMesh) -> bool:
- """Check if a fetcher is registered with the given name."""
- return name in self.fetchers
+ async def has_fetcher(self, proc_id: str) -> bool:
+ """Check if a fetcher is registered with the given proc_id."""
+ return proc_id in self.fetchers
@endpoint
- def get_fetcher_count(self) -> int:
+ async def get_fetcher_count(self) -> int:
return len(self.fetchers)
@endpoint
- async def shutdown(self):
- # Finish per-rank logger_backends via fetchers
+ async def shutdown(self) -> None:
if self.fetchers:
- tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()]
- await asyncio.gather(*tasks, return_exceptions=True)
+ try:
+ tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()]
+ await asyncio.wait_for(
+ asyncio.gather(*tasks, return_exceptions=True), timeout=2.0
+ )
+ except asyncio.TimeoutError:
+ logger.warning(
+ "Metric logging fetcher shutdown timed out likely due to the child process being terminated before the parent."
+ )
+
# Finish global logger_backends
for logger_backend_name, logger_backend in self.global_logger_backends.items():
await logger_backend.finish()
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index 990a301e0..2c5a4663f 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -4,16 +4,64 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+import asyncio
+import heapq
+import itertools
+import json
import logging
-
import os
+import time
from abc import ABC, abstractmethod
+from dataclasses import dataclass
from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
+
+from forge.observability.utils import get_proc_name_with_rank
+
+from forge.util.logging import get_logger, log_once
+from monarch.actor import current_rank
+
+logger = get_logger("INFO")
+
+
+class BackendRole(Enum):
+ """Backend role constants for metric logging actors.
+
+ Defines whether an actor operates as a local (per-rank) or global (controller) role
+ in the distributed metrics collection system.
+ """
+
+ LOCAL = "local"
+ GLOBAL = "global"
+
+
+class LoggingMode(Enum):
+ """Metric logging behavior for distributed training scenarios.
-from monarch.actor import context, current_rank
+ Each mode serves different observability needs:
-logger = logging.getLogger(__name__)
+ GLOBAL_REDUCE = "global_reduce"
+ Best for: Metrics that are best visualized as a single value per step.
+ Behavior: All ranks accumulate → controller reduces → single log entry
+ Example use: 8 ranks training, want 1 loss value per training step averaged across all
+ Where: GlobalLoggingActor logs reduced values to backends on flush.
+
+ PER_RANK_REDUCE = "per_rank_reduce"
+ Best for: Per-rank performance metrics, debugging individual rank behavior
+ Behavior: Each rank accumulates + logs its own reduced values
+ Example use: Monitor GPU utilization per rank, get 8 separate log entries per step
+ Where: MetricCollector on each rank log reduced values to backends on flush.
+
+ PER_RANK_NO_REDUCE = "per_rank_no_reduce"
+ Best for: Real-time streaming, time-series debugging
+ Behavior: Raw values logged immediately on record_metric() calls. Ignores reduce type.
+ Example use: See what every rank is doing in real time.
+ Where: MetricCollector on each rank log raw values to backends on push.
+ """
+
+ GLOBAL_REDUCE = "global_reduce"
+ PER_RANK_REDUCE = "per_rank_reduce"
+ PER_RANK_NO_REDUCE = "per_rank_no_reduce"
class Reduce(Enum):
@@ -22,6 +70,7 @@ class Reduce(Enum):
MAX = "max"
MIN = "min"
STD = "std"
+ SAMPLE = "sample"
@property
def accumulator_class(self):
@@ -31,115 +80,101 @@ def accumulator_class(self):
Reduce.MAX: MaxAccumulator,
Reduce.MIN: MinAccumulator,
Reduce.STD: StdAccumulator,
+ Reduce.SAMPLE: SampleAccumulator,
}
return mapping[self]
-def get_actor_name_with_rank() -> str:
- """
- Extracts actor information from Monarch context to form a logging name.
+@dataclass
+class Metric:
+ """Container for metric data including key, value, reduction type, and timestamp.
- Returns:
- str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0").
- Falls back to "UnknownActor" if context unavailable.
+ Timestamp is automatically set to current time if not provided.
"""
- # Add more defensive checks
- ctx = context()
- if ctx is None or ctx.actor_instance is None:
- logger.warning("Context unavailable, using fallback actor name for logging.")
- return "UnknownActor"
-
- actor_instance = ctx.actor_instance
- rank = current_rank()
-
- actor_id_full = str(actor_instance.actor_id)
-
- # Parse the actor_id
- parts = actor_id_full.split(".")
- rank_name = "UnknownActor" # fallback
- if len(parts) >= 2:
- world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]"
- actor_part = parts[1] # e.g., "TestActorConfigured[0]"
-
- # Extract world ID and proc rank
- world_id = world_part.split("[")[0] if "[" in world_part else world_part
-
- # Extract clean actor name (remove "Configured" suffix if present)
- if "[" in actor_part:
- actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured"
- if actor_name.endswith("Configured"):
- actor_name = actor_name[:-10] # Remove "Configured"
- else:
- actor_name = actor_part
-
- # Use last 4 characters of world_id as replica identifier
- # This is deterministic, readable, and works for any number of replicas
- replica_id = world_id[-4:] if len(world_id) >= 4 else world_id
-
- # Use current_rank().rank as the local rank within the replica
- local_rank = rank.rank
- rank_name = f"{actor_name}_{replica_id}_r{local_rank}"
+ key: str
+ value: Any
+ reduction: Reduce
+ timestamp: float | None = None
- return rank_name
+ def __post_init__(self):
+ if self.timestamp is None:
+ self.timestamp = time.time()
def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None:
- """
- Records a metric value for later reduction and logging.
+ """Thin wrapper to send metrics to per-rank local MetricCollectors.
Relies on a per-rank MetricCollector singleton for ease of use, i.e.
call `record_metric` anywhere in the code without moving the
collector from function to function.
- The collector methods are triggered per-rank by a
- `forge.observability.metric_actors.LocalFetcherActor`, instantiated
- during actor initialization.
-
- Records are flushed when `forge.observability.metric_actors.GlobalLoggingActor.flush()`
- is called, typically triggered by the training loop at regular intervals.
-
Can be disabled globally by setting the environment variable `FORGE_DISABLE_METRICS=true`.
+
+ Collected metrics are flushed to backends on flush(), generally:
+ GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
"""
- # Skip metrics collection if disabled for tests
+ # Skip metrics collection
if os.getenv("FORGE_DISABLE_METRICS", "false").lower() == "true":
return
+ # timestamp is added automatically by the Metric class
+ metric = Metric(key=key, value=value, reduction=reduction)
collector = MetricCollector()
- collector.push(key, value, reduction)
+ collector.push(metric)
-def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]:
- """Reduce metric accumulators states to a single value per metric.
+def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metric]:
+ """Reduce metric accumulators states to a list of Metrics.
Can be used when reducing metrics across ranks or services, as merging
states is more precise than merging locally reduced metrics.
Args:
- states (List[Dict[str, Dict[str, Any]]]): List of state of one or more metrics,
+ states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics,
normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`.
Returns:
- Dict[str, Any]: Dictionary with format {metric_key: reduced_value}
+ list[Metric]: List of reduced metrics
Example:
- states = [
- {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
- {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
+ >>> states = [
+ ... {
+ ... "loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN},
+ ... "reward/sample": {
+ ... "reduction_type": Reduce.Sample,
+ ... "samples": [{"episode_id": 1, "reward": 0.5}],
+ ... },
+ ... },
+ ... {
+ ... "loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN},
+ ... "reward/sample": {
+ ... "reduction_type": Reduce.Sample,
+ ... "samples": [{"episode_id": 2, "reward": 1.0}],
+ ... },
+ ... },
+ ... ]
+ >>> reduce_metrics_states(states)
+ [
+ Metric(key='loss', value=2.0, reduction=Reduce.MEAN), # (14 + 16) / (5 + 10) = 2.0
+ Metric(
+ key='reward/sample',
+ value=[{'episode_id': 1, 'reward': 0.5}, {"episode_id": 2, "reward": 1.0}],
+ reduction=Reduce.SAMPLE,
+ )
]
- reduce_metrics_states(states)
- >>> {"loss": 2.0}
Raises:
ValueError: on mismatched reduction types for the same metric key.
"""
if not states:
- return {}
+ return []
# Collect unique keys across all
all_keys = set(k for state in states for k in state)
- reduced_metrics = {}
+ # For each metric key, reduce the states
+ reduced_metrics = []
for key in all_keys:
metric_states = [state.get(key) for state in states if key in state]
if not metric_states:
@@ -158,7 +193,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str,
metric_accumulator = Reduce(first_reduction_type).accumulator_class
reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states)
- reduced_metrics[key] = reduced_value
+
+ # Create Metric object with reduced value
+ metric = Metric(
+ key=key,
+ value=reduced_value,
+ reduction=Reduce(first_reduction_type),
+ )
+ reduced_metrics.append(metric)
return reduced_metrics
@@ -171,8 +213,9 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str,
class MetricAccumulator(ABC):
"""Every metric maps to a MetricAccumulator, which accumulates values and optionally reduces them."""
- def __init__(self, reduction: Reduce):
+ def __init__(self, reduction: Reduce) -> None:
self.reduction_type = reduction
+ self.is_reset = True
@abstractmethod
def append(self, value: Any) -> None:
@@ -185,13 +228,13 @@ def get_value(self) -> Any:
pass
@abstractmethod
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
"""Returns serializable state for cross-rank merge (e.g., {'sum': 10.0, 'count': 5})."""
pass
@classmethod
@abstractmethod
- def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> Any:
+ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> Any:
"""Merges states from multiple ranks into single reduced value (e.g., total_sum/total_count for MEAN)."""
pass
@@ -202,20 +245,22 @@ def reset(self) -> None:
class MeanAccumulator(MetricAccumulator):
- def __init__(self, reduction: Reduce):
+ def __init__(self, reduction: Reduce) -> None:
super().__init__(reduction)
self.sum = 0.0
self.count = 0
+ self.is_reset = True
def append(self, value: Any) -> None:
v = float(value.item() if hasattr(value, "item") else value)
+ self.is_reset = False
self.sum += v
self.count += 1
def get_value(self) -> float:
return self.sum / self.count if self.count > 0 else 0.0
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {
"reduction_type": self.reduction_type.value,
"sum": self.sum,
@@ -223,94 +268,112 @@ def get_state(self) -> Dict[str, Any]:
}
@classmethod
- def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
+ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float:
total_sum = sum(s["sum"] for s in states)
total_count = sum(s["count"] for s in states)
return total_sum / total_count if total_count > 0 else 0.0
def reset(self) -> None:
+ self.is_reset = True
self.sum = 0.0
self.count = 0
class SumAccumulator(MetricAccumulator):
- def __init__(self, reduction: Reduce):
+ def __init__(self, reduction: Reduce) -> None:
super().__init__(reduction)
self.total = 0.0
+ self.is_reset = True
def append(self, value: Any) -> None:
v = float(value.item() if hasattr(value, "item") else value)
+ self.is_reset = False
self.total += v
def get_value(self) -> float:
return self.total
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {"reduction_type": self.reduction_type.value, "total": self.total}
@classmethod
- def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
+ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float:
return sum(s["total"] for s in states)
def reset(self) -> None:
+ self.is_reset = True
self.total = 0.0
class MaxAccumulator(MetricAccumulator):
- def __init__(self, reduction: Reduce):
+ def __init__(self, reduction: Reduce) -> None:
super().__init__(reduction)
self.max_val = float("-inf")
+ self.is_reset = True
def append(self, value: Any) -> None:
v = float(value.item() if hasattr(value, "item") else value)
+ self.is_reset = False
self.max_val = max(self.max_val, v)
def get_value(self) -> float:
return self.max_val
- def get_state(self) -> Dict[str, Any]:
- return {"reduction_type": self.reduction_type.value, "max_val": self.max_val}
+ def get_state(self) -> dict[str, Any]:
+ return {
+ "reduction_type": self.reduction_type.value,
+ "max_val": self.max_val,
+ }
@classmethod
- def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
+ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float:
return max(s["max_val"] for s in states)
def reset(self) -> None:
+ self.is_reset = True
self.max_val = float("-inf")
class MinAccumulator(MetricAccumulator):
- def __init__(self, reduction: Reduce):
+ def __init__(self, reduction: Reduce) -> None:
super().__init__(reduction)
self.min_val = float("inf")
+ self.is_reset = True
def append(self, value: Any) -> None:
v = float(value.item() if hasattr(value, "item") else value)
+ self.is_reset = False
self.min_val = min(self.min_val, v)
def get_value(self) -> float:
return self.min_val
- def get_state(self) -> Dict[str, Any]:
- return {"reduction_type": self.reduction_type.value, "min_val": self.min_val}
+ def get_state(self) -> dict[str, Any]:
+ return {
+ "reduction_type": self.reduction_type.value,
+ "min_val": self.min_val,
+ }
@classmethod
- def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
+ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float:
return min(s["min_val"] for s in states)
def reset(self) -> None:
+ self.is_reset = True
self.min_val = float("inf")
class StdAccumulator(MetricAccumulator):
- def __init__(self, reduction: Reduce):
+ def __init__(self, reduction: Reduce) -> None:
super().__init__(reduction)
self.sum = 0.0
self.sum_sq = 0.0
self.count = 0
+ self.is_reset = True
def append(self, value: Any) -> None:
v = float(value.item() if hasattr(value, "item") else value)
+ self.is_reset = False
self.sum += v
self.sum_sq += v * v
self.count += 1
@@ -324,7 +387,7 @@ def get_value(self) -> float:
variance = (self.sum_sq / self.count) - (mean * mean)
return max(0.0, variance) ** 0.5
- def get_state(self) -> Dict[str, Any]:
+ def get_state(self) -> dict[str, Any]:
return {
"reduction_type": self.reduction_type.value,
"sum": self.sum,
@@ -333,7 +396,7 @@ def get_state(self) -> Dict[str, Any]:
}
@classmethod
- def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
+ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float:
total_sum = sum(s["sum"] for s in states)
total_sum_sq = sum(s["sum_sq"] for s in states)
total_count = sum(s["count"] for s in states)
@@ -346,11 +409,89 @@ def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
return max(0.0, variance) ** 0.5
def reset(self) -> None:
+ self.is_reset = True
self.sum = 0.0
self.sum_sq = 0.0
self.count = 0
+class SampleAccumulator(MetricAccumulator):
+ """Accumulator for sample-level metrics with top-k and bottom-k filtering.
+
+ Keeps the top-k and bottom-k samples by a given key (e.g., reward).
+ Useful for logging only the best and worst samples from a batch.
+
+ **NOTE**: Currently the init attributes are not exposed to the user. It will always use the "score" key
+ to select highest/lowest score. The user can use it to define how to select the top/bottom samples, e.g.
+ "score" = reward or length or any other value.
+ """
+
+ def __init__(
+ self, reduction: Reduce, top_k: int = 1, bottom_k: int = 1, key: str = "score"
+ ):
+ super().__init__(reduction)
+ self.samples: List[Dict[str, Any]] = []
+ self.top_k = top_k
+ self.bottom_k = bottom_k
+ self.key = key
+ self._top_heap = [] # min-heap for top-k
+ self._bottom_heap = [] # max-heap for bottom-k (store -value)
+ self._counter = itertools.count() # tie-breaker id generator
+ self.is_reset = True
+
+ def append(self, value: dict) -> None:
+ if not isinstance(value, dict):
+ raise ValueError(f"Expected dict, got {type(value)}")
+
+ self.is_reset = False
+ val = value.get(self.key, 0.0)
+ idx = next(self._counter) # unique tiebreaker
+
+ # If top_k or bottom_k <= 0, it means "disable" that side of filtering (i.e., keep none).
+ # maintain top-k
+ if self.top_k > 0:
+ if len(self._top_heap) < self.top_k:
+ heapq.heappush(self._top_heap, (val, idx, value))
+ else:
+ heapq.heappushpop(self._top_heap, (val, idx, value))
+
+ # maintain bottom-k
+ if self.bottom_k > 0:
+ if len(self._bottom_heap) < self.bottom_k:
+ heapq.heappush(self._bottom_heap, (-val, idx, value))
+ else:
+ heapq.heappushpop(self._bottom_heap, (-val, idx, value))
+
+ def get_value(self) -> list[dict]:
+ """Return top-k and bottom-k filtered samples."""
+ tops = [s for _, _, s in self._top_heap]
+ bottoms = [s for _, _, s in self._bottom_heap]
+ return bottoms + tops
+
+ def get_state(self) -> Dict[str, Any]:
+ """Serialize accumulator state for cross-rank reduction."""
+ return {
+ "reduction_type": self.reduction_type.value,
+ "samples": self.get_value(),
+ }
+
+ @classmethod
+ def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]:
+ """Merge sample states across ranks."""
+ merged = []
+ for s in states:
+ merged.extend(s.get("samples", []))
+ return merged
+
+ def reset(self) -> None:
+ """Clear local samples and reset filter state."""
+ self.is_reset = True
+ self.samples.clear()
+ self._top_heap = []
+ self._bottom_heap = []
+ self._counter = itertools.count()
+
+
#############
# Collector #
#############
@@ -359,22 +500,23 @@ def reset(self) -> None:
class MetricCollector:
"""Per-rank singleton for accumulating, retrieving and flushing metrics to backends.
- A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False,
- the backend is instantiated per-rank, in the MetricCollector, otherwise it is instantiated once globally,
- in the GlobalLoggingActor.
+ Supports multiple logging backends, each with different logging modes.
+ For options, check `forge.observability.metrics.LoggerBackend` and `forge.observability.metrics.LoggingMode`.
- - Ensures one instance per process; actors call record_metric() which delegates here.
+ Behavior:
+ - Ensures one instance per rank;
+ - Using `record_metric()` delegates here;
- Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector;
- GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also
- return non-reduced states for global aggregation. This can be different for each backend.
- - Resets accumulators post-flush to avoid leaks across train steps;
+ return non-reduced states for global aggregation.
+ - Resets accumulators post-flush to avoid leaks across steps;
"""
- _instances: Dict[int, "MetricCollector"] = {}
+ _instances: dict[int, "MetricCollector"] = {}
_singleton_rank: int
def __new__(cls):
- """Singleton per-rank, ensures one instance per process."""
+ """Singleton per-rank, ensures one instance per rank."""
rank = current_rank().rank
if rank not in cls._instances:
@@ -383,121 +525,233 @@ def __new__(cls):
inst._singleton_rank = rank
else:
inst = cls._instances[rank]
+ # Defensive check for bugs in singleton implementation - should never fail in normal operation
if inst._singleton_rank != rank:
raise ValueError(
f"Singleton expected rank {inst._singleton_rank}, but saw {rank}"
)
return inst
- def __init__(self):
+ def __init__(self) -> None:
if hasattr(self, "_is_initialized"):
return
- self.accumulators: Dict[str, MetricAccumulator] = {}
+ self.accumulators: dict[str, MetricAccumulator] = {}
self.rank = current_rank().rank
- self.logger_backends: List[LoggerBackend] = []
+ self.per_rank_reduce_backends: list[LoggerBackend] = []
+ self.per_rank_no_reduce_backends: list[LoggerBackend] = []
+ self.global_step: int = 0 # Set on `init_backends` and updated on `flush`
self._is_initialized = False
+ self.proc_name_with_rank: str | None = None
async def init_backends(
self,
- metadata_per_primary_backend: Optional[Dict[str, Dict[str, Any]]],
- config: Dict[str, Any],
+ metadata_per_controller_backend: dict[str, dict[str, Any]] | None,
+ backend_config: dict[str, Any],
+ global_step: int = 0,
+ process_name: str | None = None,
+ run_config: dict[str, Any] | None = None,
) -> None:
- """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False,
- the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated
- once globally.
+ """Initialize per-rank logger backends and MetricCollector state.
+
+ A logger backend is represented by a backend class (e.g. WandBBackend, ConsoleBackend).
+ Backends are categorized by their logging_mode. For details, see `forge.observability.metrics.LoggingMode`.
Args:
- metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary
- logger backend, e.g., {"wandb": {"run_id": "abc123"}}.
- config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}.
+ metadata_per_controller_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from controller
+ for backends that require shared state across processes, e.g.,
+ {"wandb": {"shared_run_id": "abc123"}}.
+ backend_config (Dict[str, Any]): Backend configurations where each key is a backend name
+ and value contains logging_mode and backend-specific settings.
+ e.g., {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_proj"}}
+ global_step (int, default 0): Initial step for logging. Can be used when
+ resuming from a checkpoint.
+ process_name (str | None): The meaningful process name for logging.
+ run_config (dict[str, Any] | None): Your application's configuration
+ (hyperparameters, dataset, model settings) to log to backends for
+ experiment tracking.
"""
if self._is_initialized:
- logger.debug(f"Rank {self.rank}: MetricCollector already initialized")
+ logger.debug(
+ f"{self.proc_name_with_rank}: MetricCollector already initialized"
+ )
return
- # instantiate local backends if any
- for backend_name, backend_config in config.items():
- if backend_config.get("reduce_across_ranks", True):
- continue # Skip local backend instantiation and use global instead
+ self.global_step = global_step
+ self.proc_name_with_rank = get_proc_name_with_rank(process_name)
+
+ self.per_rank_reduce_backends: list[LoggerBackend] = []
+ self.per_rank_no_reduce_backends: list[LoggerBackend] = []
- # get metadata from primary backend if any
- primary_metadata = {}
- if metadata_per_primary_backend:
- primary_metadata = metadata_per_primary_backend.get(backend_name, {})
+ # Initialize backends based on logging mode
+ for backend_name, cfg in backend_config.items():
+ mode = cfg["logging_mode"]
+
+ # sanity check
+ if not isinstance(mode, LoggingMode):
+ raise TypeError(
+ f"Expected LoggingMode enum for {backend_name}.logging_mode, got {type(mode)}: {mode}."
+ )
+
+ # We should never hit this. Backend will be instantiated in GlobalLoggingActor.
+ if mode == LoggingMode.GLOBAL_REDUCE:
+ logger.debug("Skipping local instantiation for GLOBAL_REDUCE.")
+ continue
+
+ # get metadata from controller backend, if any
+ controller_metadata = {}
+ if metadata_per_controller_backend:
+ controller_metadata = metadata_per_controller_backend.get(
+ backend_name, {}
+ )
# instantiate local backend
- logger_backend = get_logger_backend_class(backend_name)(backend_config)
- await logger_backend.init(
- role="local", primary_logger_metadata=primary_metadata
+ backend: LoggerBackend = get_logger_backend_class(backend_name)(**cfg)
+ await backend.init(
+ role=BackendRole.LOCAL,
+ controller_logger_metadata=controller_metadata,
+ process_name=self.proc_name_with_rank,
+ run_config=run_config,
)
- self.logger_backends.append(logger_backend)
+
+ # Categorize by logging mode
+ if mode == LoggingMode.PER_RANK_NO_REDUCE:
+ self.per_rank_no_reduce_backends.append(backend)
+ else:
+ self.per_rank_reduce_backends.append(backend)
self._is_initialized = True
- def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None:
+ def push(self, metric: Metric) -> None:
+ """Process a metric according to configured logging modes.
+
+ Behavior depends on backend modes:
+ - PER_RANK_NO_REDUCE: Stream metric immediately to backends
+ - PER_RANK_REDUCE/GLOBAL_REDUCE: Accumulate for per step batch logging
+
+ Args:
+ metric (Metric): Metric dataclass
+
+ Example:
+ collector = MetricCollector()
+ metric = Metric("loss", 0.5, Reduce.MEAN)
+ collector.push(metric) # Streams immediately if no_reduce, else accumulates
+ """
+ # sanity check
if not self._is_initialized:
- raise ValueError("Collector not initialized—call init first")
+ log_once(
+ logger,
+ level=logging.WARNING,
+ msg=(
+ f"Skipping metric collection for {get_proc_name_with_rank()}."
+ " Metric logging backends (e.g. wandb) were not initialized."
+ " This happens when you try to use `record_metric` before calling `init_backends`."
+ " To disable this warning, please call in your main file:\n"
+ "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n"
+ "`await mlogger.init_backends.call_one(logging_config)`\n"
+ "or set env variable `FORGE_DISABLE_METRICS=True`"
+ ),
+ )
+ return
- if key not in self.accumulators:
- self.accumulators[key] = reduction.accumulator_class(reduction)
+ # Validate metric object
+ if not isinstance(metric, Metric):
+ raise TypeError(
+ f"Expected {Metric} object, got {metric} of type {type(metric)}"
+ )
+
+ # For PER_RANK_NO_REDUCE backends: stream without reduce
+ for backend in self.per_rank_no_reduce_backends:
+ if metric.reduction == Reduce.SAMPLE:
+ asyncio.create_task(backend.log_samples([metric], self.global_step))
+ else:
+ backend.log_stream(metric=metric, global_step=self.global_step)
- self.accumulators[key].append(value)
+ # Always accumulate for reduction and state return
+ key = metric.key
+ if key not in self.accumulators:
+ self.accumulators[key] = metric.reduction.accumulator_class(
+ metric.reduction
+ )
+ self.accumulators[key].append(metric.value)
async def flush(
- self, step: int, return_state: bool = False
- ) -> Dict[str, Dict[str, Any]]:
+ self, global_step: int, return_state: bool = False
+ ) -> dict[str, dict[str, Any]]:
"""Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True.
Args:
- step (int): Step used by backends to align metrics on the same x-axis
+ global_step (int): step used by backends to align metrics on the same x-axis
return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
If False, returns empty dict, else returns the state of all metrics collected.
Returns:
- Dict[str, Dict[str, Dict[str, Any]]]: Dict of {metric_key: metric_state},
+ dict[str, dict[str, Any]]: Dict of {metric_key: metric_state},
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""
if not self._is_initialized:
- logger.debug(
- f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first."
+ log_once(
+ logger,
+ level=logging.WARNING,
+ msg=f"Cannot flush collected metrics for {get_proc_name_with_rank()}. "
+ " MetricCollector.flush() called before init_backends()."
+ "\nPlease call in your main file:\n"
+ "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n"
+ "`await mlogger.init_backends.call_one(logging_config)`\n"
+ "before calling `flush`",
)
return {}
if not self.accumulators:
logger.debug(
- f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}"
+ f"Collector {self.proc_name_with_rank}: No metrics to flush for global_step {global_step}"
)
return {}
# Snapshot states and reset immediately
states = {}
for key, acc in self.accumulators.items():
+ # Skip state if nothing was accumulated
+ if acc.is_reset:
+ continue
states[key] = acc.get_state()
acc.reset()
- # Reduce metrics from states for logging if any per-rank backend
- if self.logger_backends:
- metrics = {}
- for key, state in states.items():
- acc_class = Reduce(state["reduction_type"]).accumulator_class
- metrics[key] = acc_class.get_reduced_value_from_states([state])
-
- # Log to local logger_backends
- for logger_backend in self.logger_backends:
- await logger_backend.log(metrics, step)
+ # Reduce and log to PER_RANK_REDUCE backends only (NO_REDUCE backends already logged in push)
+ if self.per_rank_reduce_backends:
+ metrics_for_backends = reduce_metrics_states([states])
+
+ # Split into scalar metrics and sample metrics
+ scalar_metrics = [
+ m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE
+ ]
+ sample_metrics = [
+ m for m in metrics_for_backends if m.reduction == Reduce.SAMPLE
+ ]
+
+ for backend in self.per_rank_reduce_backends:
+ if scalar_metrics:
+ await backend.log_batch(scalar_metrics, global_step)
+ if sample_metrics:
+ await backend.log_samples(sample_metrics, global_step)
+
+ # Update step counter for streaming backends
+ # Note: This is incremented AFTER flush completes, so metrics recorded between
+ # flush(N) and flush(N+1) will stream with global_step=N+1.
+ self.global_step = global_step + 1
return states if return_state else {}
async def shutdown(self):
"""Shutdown logger_backends if initialized."""
+
if not self._is_initialized:
logger.debug(
- f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown"
+ f"Collector for {self.proc_name_with_rank} not initialized. Skipping shutdown"
)
return
- for logger_backend in self.logger_backends:
- await logger_backend.finish()
+ for backend in self.per_rank_reduce_backends + self.per_rank_no_reduce_backends:
+ await backend.finish()
###########
@@ -506,65 +760,131 @@ async def shutdown(self):
class LoggerBackend(ABC):
- """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc."""
+ """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.
- def __init__(self, logger_backend_config: Dict[str, Any]):
- self.logger_backend_config = logger_backend_config
+ Args:
+ logging_mode: Logging behavior mode.
+ per_rank_share_run: Whether ranks share run. Default False.
+ **kwargs: Backend-specific arguments (e.g., project, name, tags for WandB).
+ """
+
+ def __init__(
+ self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
+ ) -> None:
+ self.logging_mode = logging_mode
+ self.per_rank_share_run = per_rank_share_run
+ self.backend_kwargs = kwargs
@abstractmethod
async def init(
self,
- role: str,
- primary_logger_metadata: Optional[Dict[str, Any]] = None,
+ role: BackendRole,
+ controller_logger_metadata: dict[str, Any] | None = None,
+ process_name: str | None = None,
+ run_config: dict[str, Any] | None = None,
) -> None:
"""
Initializes backend, e.g. wandb.run.init().
Args:
- role (str): "global" (controller/primary) or "local" (per-rank/secondary).
- Can be used to behave differently for primary vs secondary roles.
- primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for
+ role (BackendRole): BackendRole.GLOBAL (controller) or BackendRole.LOCAL (per-rank).
+ Can be used to behave differently for controller vs rank roles.
+ controller_logger_metadata (dict[str, Any] | None): From global backend for
backend that required shared info, e.g. {"shared_run_id": "abc123"}.
+ process_name (str | None): Process name for logging.
+ run_config (dict[str, Any] | None): Your application's configuration
+ (hyperparameters, dataset, model settings) to log to backend for
+ experiment tracking.
Raises: ValueError if missing metadata for shared local init.
"""
- if primary_logger_metadata is None:
- primary_logger_metadata = {}
pass
- async def log(self, metrics: Dict[str, Any], step: int) -> None:
+ @abstractmethod
+ async def log_batch(
+ self, metrics: list[Metric], global_step: int, *args, **kwargs
+ ) -> None:
+ """Log batch of accumulated metrics to backend
+
+ Args:
+ metrics: List of Metric objects to log.
+ global_step: Step number for x-axis alignment across metrics."""
+ pass
+
+ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
+ """Stream single metric to backend immediately.
+
+ NOTE: This method is called synchronously.
+ If your backend requires async I/O operations:
+ - Use asyncio.create_task() for fire-and-forget logging
+ - Consider internal buffering to avoid blocking the caller
+
+ Example for async backend:
+ def log_stream(self, metric, global_step):
+ asyncio.create_task(self._async_log(metric, global_step))
+ """
pass
+ @abstractmethod
+ async def log_samples(self, samples: List[Metric], step: int) -> None:
+ """Log samples to backend.
+
+ Args:
+ samples: List of Metric objects to log.
+ step: Step number for x-axis alignment across metrics.
+ """
+ pass
+
+ @abstractmethod
async def finish(self) -> None:
pass
- def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]:
- """Return sharable state after primary init (e.g., for shared modes). Called only on globals."""
+ def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None:
+ """Return sharable state after controller init (e.g., for shared modes). Called only on controller backends."""
return None
class ConsoleBackend(LoggerBackend):
"""Simple console logging of metrics."""
- def __init__(self, logger_backend_config: Dict[str, Any]):
- super().__init__(logger_backend_config)
+ def __init__(
+ self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
+ ) -> None:
+ super().__init__(
+ logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
+ )
+ self.process_name = None
async def init(
self,
- role: str,
- primary_logger_metadata: Optional[Dict[str, Any]] = None,
+ role: BackendRole,
+ controller_logger_metadata: dict[str, Any] | None = None,
+ process_name: str | None = None,
+ run_config: dict[str, Any] | None = None,
+ ) -> None:
+ self.process_name = process_name
+
+ async def log_batch(
+ self, metrics: list[Metric], global_step: int, *args, **kwargs
) -> None:
- self.prefix = (
- get_actor_name_with_rank()
- if self.logger_backend_config.get("reduce_across_ranks", True)
- else "GLOBAL"
+ metrics_str = "\n".join(
+ f" {metric.key}: {metric.value}"
+ for metric in sorted(metrics, key=lambda m: m.key)
)
+ logger.info(
+ f"=== [{self.process_name}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n"
+ )
+
+ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
+ logger.info(f"{metric.key}: {metric.value}")
+
+ async def log_samples(self, samples: List[Metric], step: int) -> None:
+ """Pretty-print sample-level logs to console."""
- async def log(self, metrics: Dict[str, Any], step: int) -> None:
- logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===")
- for key, value in sorted(metrics.items()):
- logger.info(f" {key}: {value}")
- logger.info("==============================\n")
+ for sample in samples:
+ table_name, table_rows = sample.key, sample.value
+ logger.info(f"[{table_name}] ({len(table_rows)} samples)")
+ logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
async def finish(self) -> None:
pass
@@ -572,121 +892,235 @@ async def finish(self) -> None:
class WandbBackend(LoggerBackend):
"""
- Weights & Biases logging backend for distributed training.
+ Weights & Biases logging backend.
- Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/:
- Track a single process: reduce_across_ranks=True
- Track each process separately: reduce_across_ranks=False, share_run_id=False
- Track all processes to a single run: reduce_across_ranks=False, share_run_id=True
+ For logging mode details, see `forge.observability.metrics.LoggingMode` documentation.
+
+ More details on wandb distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/
Configuration:
- reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode).
- If False, enables per-rank logging; then use share_run_id to pick mode.
- share_run_id (bool, default False): Only used if reduce_across_ranks=False.
- True -> shared run across ranks; False -> separate runs per rank.
- project (str): WandB project name
- group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group"
+ logging_mode (LoggingMode): Determines logging behavior.
+ per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
+ If true, a single wandb run is created and all ranks log to it. Particularly useful for
+ logging with no_reduce to capture time-based streams. Not recommended if reducing values.
+ **kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)
+
+ Example:
+ WandbBackend(
+ logging_mode=LoggingMode.PER_RANK_REDUCE,
+ per_rank_share_run=False,
+ project="my_project",
+ group="exp_group",
+ name="my_experiment",
+ tags=["rl", "v2"],
+ notes="Testing new reward"
+ )
"""
- def __init__(self, logger_backend_config: Dict[str, Any]):
- super().__init__(logger_backend_config)
- self.project = logger_backend_config["project"]
- self.group = logger_backend_config.get("group", "experiment_group")
- self.name = None
- self.run = None
- self.reduce_across_ranks = logger_backend_config.get(
- "reduce_across_ranks", True
+ def __init__(
+ self, *, logging_mode: LoggingMode, per_rank_share_run: bool = False, **kwargs
+ ) -> None:
+ super().__init__(
+ logging_mode=logging_mode, per_rank_share_run=per_rank_share_run, **kwargs
)
- self.share_run_id = logger_backend_config.get("share_run_id", False)
+ self.run = None
+ self.process_name = None
+ self._tables: dict[str, "wandb.Table"] = {}
async def init(
self,
- role: str,
- primary_logger_metadata: Optional[Dict[str, Any]] = None,
+ role: BackendRole,
+ controller_logger_metadata: dict[str, Any] | None = None,
+ process_name: str | None = None,
+ run_config: dict[str, Any] | None = None,
) -> None:
-
- if primary_logger_metadata is None:
- primary_logger_metadata = {}
-
- if role not in ["global", "local"]:
- raise ValueError(
- f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'."
- )
-
- self.name = (
- get_actor_name_with_rank() if role == "local" else "global_controller"
- )
-
- # Default global mode: only inits on controller
- if self.reduce_across_ranks:
- if role != "global":
- logger.debug(
- f"Skipped init for global mode (reduce_across_ranks=True) and {role} role."
- )
+ if controller_logger_metadata is None:
+ controller_logger_metadata = {}
+
+ # Pop name, if any, to concat to process_name.
+ run_name = self.backend_kwargs.pop("name", None)
+ self.process_name = process_name
+ self.run_config = run_config
+
+ # Format run name based on mode and role
+ if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
+ if role != BackendRole.GLOBAL:
+ logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.")
return
- await self._init_global()
-
- # Per-rank modes based on share_run_id bool
- elif role == "global" and self.share_run_id:
- await self._init_shared_global()
-
- elif role == "local":
- if self.share_run_id:
- await self._init_shared_local(primary_logger_metadata)
+ # use name as-is, no need to append controller process_name
+ await self._init_global(run_name)
+
+ elif role == BackendRole.GLOBAL and self.per_rank_share_run:
+ # use name as-is, no need to append controller process_name
+ await self._init_shared_global(run_name)
+
+ elif role == BackendRole.LOCAL:
+ # Per-rank: append process_name
+ run_name = f"{run_name}_{process_name}" if run_name else process_name
+
+ if self.per_rank_share_run:
+ shared_id = controller_logger_metadata.get("shared_run_id")
+ if shared_id is None:
+ raise ValueError(
+ f"Shared ID required but not provided for {process_name} backend init"
+ )
+ await self._init_shared_local(run_name, shared_id, process_name)
else:
- await self._init_per_rank()
+ await self._init_per_rank(run_name)
- async def _init_global(self):
+ async def _init_global(self, run_name: str | None):
import wandb
- self.run = wandb.init(project=self.project, group=self.group)
+ self.run = wandb.init(
+ name=run_name, config=self.run_config, **self.backend_kwargs
+ )
- async def _init_per_rank(self):
+ async def _init_per_rank(self, run_name: str):
import wandb
- self.run = wandb.init(project=self.project, group=self.group, name=self.name)
+ self.run = wandb.init(
+ name=run_name, config=self.run_config, **self.backend_kwargs
+ )
- async def _init_shared_global(self):
+ async def _init_shared_global(self, run_name: str | None):
import wandb
settings = wandb.Settings(
mode="shared", x_primary=True, x_label="controller_primary"
)
- self.run = wandb.init(project=self.project, group=self.group, settings=settings)
+ self.run = wandb.init(
+ name=run_name,
+ config=self.run_config,
+ settings=settings,
+ **self.backend_kwargs,
+ )
- async def _init_shared_local(self, primary_metadata: Dict[str, Any]):
+ async def _init_shared_local(
+ self, run_name: str, shared_id: str, process_name: str
+ ):
import wandb
- shared_id = primary_metadata.get("shared_run_id")
- if shared_id is None:
- raise ValueError(
- f"Shared ID required but not provided for {self.name} backend init"
- )
- settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name)
+ # Clear any stale service tokens that might be pointing to dead processes
+ # In multiprocessing environments, WandB service tokens can become stale and point
+ # to dead service processes. This causes wandb.init() to hang indefinitely trying
+ # to connect to non-existent services. Clearing forces fresh service connection.
+ from wandb.sdk.lib.service import service_token
+
+ service_token.clear_service_in_env()
+
+ settings = wandb.Settings(mode="shared", x_primary=False, x_label=process_name)
self.run = wandb.init(
+ name=run_name,
id=shared_id,
- project=self.project,
- group=self.group,
+ config=self.run_config,
settings=settings,
+ **self.backend_kwargs,
)
- async def log(self, metrics: Dict[str, Any], step: int) -> None:
- if self.run:
- log_data = {**metrics, "global_step": step}
- self.run.log(log_data)
- logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}")
- else:
- logger.debug(f"WandbBackend: No run started, skipping log for {self.name}")
+ async def log_batch(
+ self, metrics: list[Metric], global_step: int, *args, **kwargs
+ ) -> None:
+ if not self.run:
+ logger.debug(
+ f"WandbBackend: No run started, skipping log for {self.process_name}"
+ )
+ return
- def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]:
- if self.run and not self.reduce_across_ranks and self.share_run_id:
+ # Convert metrics to WandB log format
+ log_data = {}
+ for metric in metrics:
+ log_data[metric.key] = metric.value
+
+ self.run.log(log_data, step=global_step)
+ logger.info(
+ f"WandbBackend: Logged {len(metrics)} metrics at step {global_step}"
+ )
+
+ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
+ """Stream single metric to WandB with both step and timestamp."""
+ if not self.run:
+ return
+
+ # Log with custom timestamp for precision
+ # Users can choose x-axis as timestamp in WandB UI and display as datetime
+ log_data = {
+ metric.key: metric.value,
+ "timestamp": metric.timestamp,
+ }
+
+ # note: here we dont use step since wandb keeps only the latest value for each step
+ self.run.log(log_data)
+
+ async def log_samples(self, samples: List[Metric], step: int) -> None:
+ """Log sample-level data incrementally to persistent WandB Tables."""
+ import wandb
+
+ if not self.run:
+ return
+
+ for sample in samples:
+ table_name, table_rows = sample.key, sample.value
+ if not table_rows:
+ continue
+
+ # If table doesn't exist yet, create it in INCREMENTAL mode
+ if table_name not in self._tables:
+ # Collect all unique columns from all rows
+ columns = set()
+ for row in table_rows:
+ columns.update(row.keys())
+ columns = sorted(columns) # Sort for consistent column ordering
+ table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
+ self._tables[table_name] = table
+ logger.debug(
+ f"WandbBackend: Created new incremental table: {table_name} with columns: {columns}"
+ )
+ else:
+ table = self._tables[table_name]
+
+ # Add rows (fill missing columns with None)
+ for s in table_rows:
+ # Check for extra columns not in the table schema
+ extra_columns = set(s.keys()) - set(table.columns)
+ if extra_columns:
+ logger.warning(
+ f"WandbBackend: Row has extra columns not in table '{table_name}': {sorted(extra_columns)}. "
+ f"These will be ignored."
+ )
+ values = [s.get(c) for c in table.columns]
+ table.add_data(*values)
+
+ # Log the same table object (INCREMENTAL update)
+ # table_name has to end with _table to be recognized by wandb
+ if not table_name.endswith("_table"):
+ table_name += "_table"
+ self.run.log({f"{table_name}": table})
+
+ def get_metadata_for_secondary_ranks(self) -> dict[str, Any]:
+ if self.run and self.per_rank_share_run:
return {"shared_run_id": self.run.id}
return {}
async def finish(self) -> None:
+ import wandb
+
if self.run:
+ """
+ Convert each incremental table to immutable before finishing
+ as recommended by wandb:
+ https://docs.wandb.ai/models/tables/log_tables#incremental-mode
+ """
+ for table_name, incr_table in self._tables.items():
+ final_table = wandb.Table(
+ columns=incr_table.columns,
+ data=incr_table.data,
+ log_mode="IMMUTABLE",
+ )
+ self.run.log({table_name: final_table})
+ logger.debug(f"WandbBackend: Finalized table {table_name}")
+
self.run.finish()
- logger.info(f"WandbBackend {self.name}: Finished run")
+ logger.info(f"WandbBackend {self.process_name}: Finished run")
def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]:
diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py
index e85b81e26..13c895346 100644
--- a/src/forge/observability/perf_tracker.py
+++ b/src/forge/observability/perf_tracker.py
@@ -3,21 +3,22 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+
import inspect
import logging
-import os
import threading
import time
-
from concurrent.futures import Future, ThreadPoolExecutor
from functools import lru_cache, wraps
-from typing import List, Optional, Protocol, Tuple
+from typing import Protocol
import torch
-from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA
+from forge.env import DISABLE_PERF_METRICS, FORGE_DISABLE_METRICS, METRIC_TIMER_USES_GPU
from forge.observability.metrics import record_metric, Reduce
+logger = logging.getLogger(__name__)
+
# Thread-local memory tracking state
_local = threading.local()
@@ -44,7 +45,6 @@ def _warn_nested_memory_tracking(prefix: str) -> None:
"""
-
class Tracer:
==========
"""
@@ -107,11 +107,13 @@ def __init__(
self.prefix = prefix
self.track_memory = track_memory
self.time_with_gpu = timer == "gpu"
- self._disable = os.getenv(DISABLE_PERF_METRICS, "false") == "true"
+ self._disable = (
+ DISABLE_PERF_METRICS.get_value() or FORGE_DISABLE_METRICS.get_value()
+ )
self._active = False
# Timing state
- self._timer: Optional[_TimerProtocol] = None
+ self._timer: _TimerProtocol | None = None
# Memory tracking state
self._memory_started = False
@@ -124,9 +126,20 @@ def start(self) -> None:
raise ValueError("Tracer has already been started")
# Start timing (always enabled)
- time_with_gpu_events = (
- os.getenv(METRIC_TIMER_USES_CUDA, str(self.time_with_gpu)).lower() == "true"
- ) and torch.cuda.is_available()
+
+ # TODO - follow up on if this env var behavior makes sense.
+ # METRIC_TIMER_USES_GPU env var overrides the timer parameter
+ metric_timer_uses_gpu = METRIC_TIMER_USES_GPU.get_value()
+ if metric_timer_uses_gpu is not None:
+ # Env var is set - convert string to bool if needed
+ if isinstance(metric_timer_uses_gpu, str):
+ use_gpu = metric_timer_uses_gpu.lower() in ("true", "1", "yes")
+ else:
+ use_gpu = bool(metric_timer_uses_gpu)
+ else:
+ # Env var not set - use the timer parameter
+ use_gpu = self.time_with_gpu
+ time_with_gpu_events = use_gpu and torch.cuda.is_available()
self._timer = _TimerCUDA() if time_with_gpu_events else _TimerCPU()
self._timer.start()
@@ -150,10 +163,9 @@ def stop(self) -> None:
if not self._active:
raise ValueError("Tracer must be started before calling stop")
- # Stop timing (always enabled)
- # step("end") is dropped from steps, but included in total sum
- self._timer.step("end") # pyre-ignore
- self._record_timing_metrics()
+ # Stop timing
+ durations, stop_step_ms = self._timer.get_all_durations() # pyre-ignore
+ self._record_timing_metrics(durations, stop_step_ms)
self._timer = None
# Stop memory tracking
@@ -174,7 +186,7 @@ def _start_memory_tracking(self) -> None:
if should_track:
_set_memory_active(True)
- torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_peak_memory_stats()
self._start_mem = torch.cuda.memory_allocated()
self._memory_started = True
@@ -190,34 +202,29 @@ def _stop_memory_tracking(self) -> None:
)
record_metric(f"{self.prefix}/memory_peak_max_gb", peak_mem, Reduce.MAX)
_set_memory_active(False)
- torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_peak_memory_stats()
self._memory_started = False
- def _record_timing_metrics(self) -> None:
- durations = self._timer.get_all_durations() # pyre-ignore
-
- # Total: sum all recorded durations (full timeline including end)
- total_ms = sum(d_ms for name, d_ms in durations)
+ def _record_timing_metrics(
+ self, durations: list[tuple[str, float]], stop_step_ms: float
+ ) -> None:
+ total_ms = sum(d_ms for _, d_ms in durations) + stop_step_ms
total_s = total_ms / 1000.0
record_metric(f"{self.prefix}/total_duration_avg_s", total_s, Reduce.MEAN)
record_metric(f"{self.prefix}/total_duration_max_s", total_s, Reduce.MAX)
- # Steps: record each individually (drop last "end")
- for name, d_ms in durations[:-1]:
+ for name, d_ms in durations:
d_s = d_ms / 1000.0
record_metric(f"{self.prefix}/{name}/duration_avg_s", d_s, Reduce.MEAN)
record_metric(f"{self.prefix}/{name}/duration_max_s", d_s, Reduce.MAX)
class _TimerProtocol(Protocol):
- def start(self) -> None:
- ...
+ def start(self) -> None: ...
- def step(self, name: str) -> None:
- ...
+ def step(self, name: str) -> None: ...
- def get_all_durations(self) -> List[Tuple[str, float]]:
- ...
+ def get_all_durations(self) -> tuple[list[tuple[str, float]], float]: ...
class _TimerCPU(_TimerProtocol):
@@ -226,8 +233,8 @@ class _TimerCPU(_TimerProtocol):
"""
def __init__(self) -> None:
- self._durations: List[Tuple[str, float]] = []
- self._chain_start: Optional[float] = None
+ self._durations: list[tuple[str, float]] = []
+ self._chain_start: float | None = None
def start(self) -> None:
# Reset state for reuse
@@ -242,24 +249,38 @@ def step(self, name: str) -> None:
self._durations.append((name, delta_ms))
self._chain_start = now
- def get_all_durations(self) -> List[Tuple[str, float]]:
- return self._durations[:]
+ def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
+ """Retrieve list of (step_name, duration) tuples and last step duration
+ between tracer.stop and the last step (or start if none)."""
+ stop_step_ms = 0.0
+ if self._chain_start is not None:
+ now = time.perf_counter()
+ stop_step_ms = (now - self._chain_start) * 1000
+ return self._durations[:], stop_step_ms
class _TimerCUDA(_TimerProtocol):
"""CUDA timing backend with non-blocking events and futures.
Uses a thread pool to poll CUDA events asynchronously without blocking the main thread.
+
+ Example:
+ timer = _TimerCUDA()
+ timer.start()
+ # torch.mm(a, b) # ~100ms GPU
+ timer.step("matmul")
+ # torch.mm(c, d) # ~200ms
+ durs_steps, stop_step_ms = timer.get_all_durations() # ([( "matmul", 100 )], 200)
"""
def __init__(self, max_workers: int = 2) -> None:
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available for timing")
self._executor = ThreadPoolExecutor(max_workers=max_workers)
- self._futures: List[
- Tuple[str, Future[float], int]
- ] = [] # (name, future, submission_index)
- self._durations: List[Tuple[str, float]] = []
- self._chain_start: Optional[torch.cuda.Event] = None
+ self._futures: list[tuple[str, Future[float], int]] = (
+ []
+ ) # (name, future, submission_index)
+ self._durations: list[tuple[str, float]] = []
+ self._chain_start: torch.cuda.Event | None = None
def start(self) -> None:
"""Call before any steps. Clear state for reuse; record initial event on current stream."""
@@ -277,7 +298,6 @@ def step(self, name: str) -> None:
Args:
name: Label for this segment's duration
"""
- # Submit polling future; chain to next event.
if self._chain_start is None:
raise ValueError("Timer must be started before calling step")
@@ -285,66 +305,63 @@ def step(self, name: str) -> None:
end_event = torch.cuda.Event(enable_timing=True)
end_event.record(stream)
- def _compute_elapsed(start_event, end_event):
- # Poll with backoff: starts fast (1ms), grows to cap (50ms) for mixed workloads.
- sleep_time = 0.001 # Start at 1ms
- while not end_event.query():
- time.sleep(sleep_time)
- sleep_time = min(sleep_time * 1.5, 0.05) # Backoff, cap at 50ms
- return start_event.elapsed_time(end_event)
-
- future = self._executor.submit(_compute_elapsed, self._chain_start, end_event)
+ future = self._executor.submit(self._poll_elapsed, self._chain_start, end_event)
index = len(self._futures)
self._futures.append((name, future, index))
-
if len(self._futures) >= 5: # clean up every 5
self._collect_completed_futures()
self._chain_start = end_event
- def _collect_completed_futures(self) -> None:
+ def _poll_elapsed(
+ self, start_event: torch.cuda.Event, end_event: torch.cuda.Event
+ ) -> float:
+ """Compute elapsed time after polling with backoff."""
+ # Poll until ready
+ sleep_time = 0.001 # Start at 1ms
+ while not end_event.query():
+ time.sleep(sleep_time)
+ sleep_time = min(sleep_time * 1.5, 0.05) # Backoff, cap at 50ms
+ return start_event.elapsed_time(end_event)
+
+ def _collect_completed_futures(self, wait_till_done: bool = False) -> None:
"""Drain done futures to avoid memory leak; update durations in submission order."""
- completed = []
still_pending = []
for name, future, idx in self._futures:
- if future.done():
- try:
- dur = future.result()
- completed.append((idx, name, dur))
- except Exception as e:
- raise RuntimeError(f"Timing failed for {name}: {e}") from e
+ if future.done() or wait_till_done:
+ dur = future.result()
+ self._durations.append((name, dur))
else:
still_pending.append((name, future, idx))
- # Sort completed by submission index to preserve order
- completed.sort(key=lambda x: x[0])
- for _, name, dur in completed:
- self._durations.append((name, dur))
-
self._futures = still_pending
- def get_all_durations(self) -> List[Tuple[str, float]]:
- """Retrieve list of (name, duration) tuples in submission order after waiting for background polls to finish."""
- # Wait and collect if pendings; return durations.
- self._collect_completed_futures()
- completed = []
- for name, future, idx in self._futures:
- try:
- dur = future.result()
- completed.append((idx, name, dur))
- except Exception as e:
- raise RuntimeError(f"Timing failed for {name}: {e}") from e
-
- # Sort by submission index to preserve order
- completed.sort(key=lambda x: x[0])
- for _, name, dur in completed:
- self._durations.append((name, dur))
+ def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
+ """Retrieve list of (step_name, duration) tuples and last step duration
+ between tracer.stop and the last step (or start if none). Order of tuples is random.
+ """
+ # Final timing since last step (or start) until this function is called
+ stop_step = f"_stop_step_{id(self)}"
+ self.step(stop_step)
+ # Wait on remaining futures
+ self._collect_completed_futures(wait_till_done=True)
self._futures.clear()
- return self._durations[:]
+
+ # Extract stop_step_ms
+ stop_step_ms = 0.0
+ durations = [
+ (name, duration) for name, duration in self._durations if name != stop_step
+ ]
+ for name, duration in self._durations:
+ if name == stop_step:
+ stop_step_ms = duration
+ break
+
+ return durations, stop_step_ms
def __del__(self) -> None:
- # Fallback cleanup in finalizer; ignores errors to avoid shutdown noise.
+ # Fallback cleanup in finalizer
try:
self._executor.shutdown(wait=True)
except Exception:
diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py
new file mode 100644
index 000000000..811bbfe41
--- /dev/null
+++ b/src/forge/observability/utils.py
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Optional
+
+from monarch.actor import context, current_rank
+
+logger = logging.getLogger(__name__)
+
+
+def get_proc_name_with_rank(proc_name: Optional[str] = None) -> str:
+ """
+ Returns a unique identifier for the current rank from Monarch actor context.
+
+ Multiple ranks from the same ProcMesh will share the same ProcMesh hash suffix,
+ but have different rank numbers.
+
+ Format: "{ProcessName}_{ProcMeshHash}_r{rank}" where:
+ - ProcessName: The provided proc_name (e.g., "TrainActor") or extracted from actor_name if None.
+ - ProcMeshHash: Hash suffix identifying the ProcMesh (e.g., "1abc2def")
+ - rank: Local rank within the ProcMesh (0, 1, 2, ...)
+
+ Note: If called from the main process (e.g. main.py), returns "client_r0".
+
+ Args:
+ proc_name: Optional override for process name. If None, uses actor_id.actor_name.
+
+ Returns:
+ str: Unique identifier per rank (e.g., "TrainActor_1abc2def_r0" or "client_r0").
+ """
+ ctx = context()
+ actor_id = ctx.actor_instance.actor_id
+ actor_name = actor_id.actor_name
+ rank = current_rank().rank
+
+ # If proc_name provided, extract procmesh hash from actor_name and combine
+ if proc_name is not None:
+ parts = actor_name.split("_")
+ if len(parts) > 1:
+ replica_hash = parts[-1] # (e.g., "MyActor_1abc2def" -> "1abc2def")
+ return f"{proc_name}_{replica_hash}_r{rank}"
+ else:
+ # if a direct process (e.g. called from main), actor_name == "client" -> len(parts) == 1
+ return f"{proc_name}_r{rank}"
+
+ # No proc_name override - use full actor_name with rank
+ return f"{actor_name}_r{rank}"
diff --git a/src/forge/rl/__init__.py b/src/forge/rl/__init__.py
new file mode 100644
index 000000000..025cfcb43
--- /dev/null
+++ b/src/forge/rl/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from forge.rl.advantage import ComputeAdvantages
+from forge.rl.collate import collate
+from forge.rl.grading import RewardActor
+from forge.rl.types import Episode, Group, Policy
+
+__all__ = [
+ "Episode",
+ "Group",
+ "Policy",
+ "collate",
+ "ComputeAdvantages",
+ "RewardActor",
+]
diff --git a/src/forge/rl/advantage.py b/src/forge/rl/advantage.py
new file mode 100644
index 000000000..b7ef0a416
--- /dev/null
+++ b/src/forge/rl/advantage.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+import torch
+
+from forge.controller.actor import ForgeActor
+from forge.rl.types import Group
+from monarch.actor import endpoint
+
+
+# TODO: this doesn't need to be an actor
+@dataclass
+class ComputeAdvantages(ForgeActor):
+ @endpoint
+ async def compute(self, group: Group) -> list[float]:
+ # TODO: add batch processing
+ rewards = torch.tensor([[e.reward for e in group]])
+ mean = rewards.mean(1, keepdim=True)
+ std = rewards.std(1, keepdim=True)
+ advantages = (rewards - mean) / (std + 1e-4)
+ return advantages.squeeze(0).tolist()
diff --git a/src/forge/rl/collate.py b/src/forge/rl/collate.py
new file mode 100644
index 000000000..456fb8caa
--- /dev/null
+++ b/src/forge/rl/collate.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any
+
+import torch
+
+from forge.rl.types import Group
+
+
+def collate(
+ batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+ """
+ Collates a list of batches into a single batch of inputs and targets.
+ Each batch is a list of episodes, and each episode is a dict of tensors.
+ """
+ inputs = []
+ targets = []
+ for batch in batches:
+ request = [e.request_tensor for e in batch]
+ request = torch.stack(request) # [b x s]
+
+ response = [e.response_tensor for e in batch]
+ response = torch.stack(response) # [b x s]
+
+ ref_logprobs = [e.ref_logprobs for e in batch]
+ ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]
+
+ advantages = [e.advantage for e in batch]
+ advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]
+
+ pad_id = batch[0].pad_id
+ mask = response != pad_id
+
+ input = {"tokens": torch.cat([request, response], dim=1)}
+ target = {
+ "response": response,
+ "ref_logprobs": ref_logprobs,
+ "advantages": advantages,
+ "padding_mask": mask,
+ }
+ inputs.append(input)
+ targets.append(target)
+ return inputs, targets
diff --git a/src/forge/rl/grading.py b/src/forge/rl/grading.py
new file mode 100644
index 000000000..bfa9f8e63
--- /dev/null
+++ b/src/forge/rl/grading.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Callable
+
+from forge.controller.actor import ForgeActor
+from forge.observability.metrics import record_metric, Reduce
+
+from monarch.actor import endpoint
+
+
+@dataclass
+class RewardActor(ForgeActor):
+ reward_functions: list[Callable]
+
+ @endpoint
+ async def evaluate_response(
+ self, prompt: str, response: str, target: str
+ ) -> (dict[str, float], float):
+ total_rewards = 0.0
+ reward_breakdown = {} # reward breakdown by function
+ for reward_fn in self.reward_functions:
+ reward = reward_fn(prompt, response, target)
+ total_rewards += reward
+
+ # Get a name for the reward function (works for classes, functions, lambdas)
+ reward_fn_name = getattr(
+ reward_fn, "__name__", reward_fn.__class__.__name__
+ )
+ reward_breakdown[reward_fn_name] = reward
+
+ # log per fn reward and avg total
+ record_metric(
+ f"reward/evaluate_response/avg_{reward_fn_name}_reward",
+ reward,
+ Reduce.MEAN,
+ )
+ record_metric(
+ f"reward/evaluate_response/std_{reward_fn_name}_reward",
+ reward,
+ Reduce.STD,
+ )
+
+ record_metric(
+ "reward/evaluate_response/avg_total_reward",
+ reward,
+ Reduce.MEAN,
+ )
+
+ avg_reward: float = total_rewards / len(self.reward_functions)
+ return reward_breakdown, avg_reward
diff --git a/src/forge/rl/types.py b/src/forge/rl/types.py
new file mode 100644
index 000000000..e54743c98
--- /dev/null
+++ b/src/forge/rl/types.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+
+from forge.actors.generator import Generator
+from forge.data_models.completion import Completion
+
+
+@dataclass
+class Episode:
+ episode_id: str
+ pad_id: int
+ request_len: int
+ response_len: int
+ target: Any | None = None
+ request: str | None = None
+ response: str | None = None
+ # Processed data
+ completion: Completion | None = None
+ ref_logprobs: torch.Tensor | None = None
+ reward: float | None = None
+ reward_breakdown: dict[str, float] | None = None
+ advantage: float | None = None
+
+ @property
+ def policy_version(self) -> int | None:
+ return self.completion.generator_version
+
+ @property
+ def request_tensor(self) -> torch.Tensor:
+ tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long)
+ if tensor.shape[0] < self.request_len: # left pad
+ diff = self.request_len - tensor.shape[0]
+ tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
+ return tensor
+
+ @property
+ def response_tensor(self) -> torch.Tensor:
+ tensor: torch.Tensor = self.completion.token_ids.to(torch.long)
+ if tensor.shape[0] < self.response_len: # right pad
+ diff = self.response_len - tensor.shape[0]
+ tensor = F.pad(tensor, (0, diff), value=self.pad_id)
+ return tensor
+
+ def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
+ """Convert episode to dict, optionally excluding specified fields."""
+ result = {
+ "episode_id": self.episode_id,
+ "policy_version": self.policy_version,
+ "prompt": self.request,
+ "response": self.response,
+ "target": str(self.target),
+ "reward": self.reward,
+ "advantage": self.advantage,
+ "request_len": self.request_len,
+ "response_len": self.response_len,
+ "pad_id": self.pad_id,
+ "ref_logprobs": self.ref_logprobs,
+ "completion": self.completion,
+ }
+
+ if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
+ result.update(self.reward_breakdown)
+
+ if exclude:
+ for key in exclude:
+ result.pop(key, None)
+
+ return result
+
+
+# Represents the group (G) of episodes in GRPO
+Group = list[Episode]
+
+# Represents the Policy Model to collect data from
+Policy = Generator
diff --git a/src/forge/types.py b/src/forge/types.py
index f79e3ef2c..8ba17bd01 100644
--- a/src/forge/types.py
+++ b/src/forge/types.py
@@ -15,15 +15,6 @@ class Message(TypedDict):
tools: dict[str, Any] | None
-@dataclass
-class ForgeEnvInfo:
- """Environment info returned with observations."""
-
- episode_id: str | None = None
- step_count: int = 0
- metadata: dict | None = None
-
-
@dataclass(kw_only=True)
class Observation:
"""Base class for environment observations.
@@ -44,50 +35,6 @@ class Observation:
metadata: dict[str, Any] = field(default_factory=dict)
-@dataclass(kw_only=True)
-class Action:
- """Base class for environment actions.
-
- Contract:
- - Should contain all information needed to execute a step in the environment
- - Should be serializable/deserializable
- - Should be immutable (or treated as such)
-
- Args:
- metadata: Additional data that may be useful for logging, debugging, or transforms
- """
-
- metadata: dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class Trajectory:
- """A trajectory containing a sequence of states, actions, etc."""
-
- policy_version: int
- states: list[Observation] = field(default_factory=list)
- actions: list[Action] = field(default_factory=list)
-
- def __post_init__(self):
- assert self.policy_version >= 0
-
-
-@dataclass(kw_only=True)
-class State:
- """Base class for environment state.
-
- Contract:
- - Should contain all information needed to restore the environment
- - Should be serializable/deserializable
- - May contain information not exposed in observations
-
- Args:
- metadata: Additional state information that may be useful for debugging or analysis
- """
-
- metadata: dict[str, Any] = field(default_factory=dict)
-
-
class Launcher(Enum):
MAST = "mast"
SLURM = "slurm"
@@ -95,7 +42,17 @@ class Launcher(Enum):
@dataclass
class ProcessConfig:
- """A proc_mesh config for the torchx scheduler."""
+ """A configuration for allocating Monarch ProcMeshes.
+
+ Args:
+ procs (int): Number of processes to launch for each replica of the service.
+ with_gpus (bool, optional): Whether to allocate GPUs for the service processes.
+ hosts (int | None, optional): Number of hosts to allocate for each replica.
+ If this is set to None, it will use the local host.
+ If this is set to a positive integer, it will run on a remote host.
+ mesh_name (str | None, optional): Name of the mesh to use for the proc_mesh.
+
+ """
procs: int = 1
with_gpus: bool = False
@@ -105,13 +62,15 @@ class ProcessConfig:
@dataclass
class ServiceConfig:
- """
- A service config.
+ """The configuration for a Forge service.
+
Args:
procs (int): Number of processes to launch for each replica of the service.
num_replicas (int): Number of replicas to launch for the service.
with_gpus (bool, optional): Whether to allocate GPUs for the service processes.
hosts (int | None, optional): Number of hosts to allocate for each replica.
+ If this is set to None, it will use the local host.
+ If this is set to a positive integer, it will run on a remote host.
health_poll_rate (float, optional): Frequency (in seconds) to poll for health status.
replica_max_concurrent_requests (int, optional): Maximum number of concurrent requests per replica.
return_first_rank_result (bool, optional): Whether to auto-unwrap ValueMesh to the first rank's result.
@@ -121,7 +80,6 @@ class ServiceConfig:
num_replicas: int
with_gpus: bool = False
hosts: int | None = None
- # ServiceConfig-specific fields
health_poll_rate: float = 0.2
replica_max_concurrent_requests: int = 10
return_first_rank_result: bool = True
@@ -129,6 +87,7 @@ class ServiceConfig:
def to_process_config(self) -> ProcessConfig:
"""Extract ProcessConfig from this ServiceConfig.
+
Maps procs to procs for ProcessConfig.
"""
return ProcessConfig(
@@ -147,9 +106,20 @@ class LauncherConfig:
"""A launcher config for the scheduler."""
launcher: Launcher
- job_name: str
- services: dict[str, ServiceConfig]
- actors: dict[str, ProcessConfig]
+ job_name: str = ""
+ services: dict[str, ServiceConfig] = field(default_factory=dict)
+ actors: dict[str, ProcessConfig] = field(default_factory=dict)
+ cpu: int | None = None # CPUs per node (required for SLURM, can get with sinfo)
+ memMB: int | None = ( # noqa: N815
+ None # Memory in MB per node (required for SLURM, can get with sinfo)
+ )
+ gpu: int = 8 # GPUs per node (required for SLURM, can get with sinfo)
+ account: str = ""
+ qos: str = ""
+
+ def __post_init__(self):
+ if isinstance(self.launcher, str):
+ self.launcher = Launcher(self.launcher)
@dataclass
diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py
index 5fb03b0f9..552c43dfc 100644
--- a/src/forge/util/__init__.py
+++ b/src/forge/util/__init__.py
@@ -5,12 +5,10 @@
# LICENSE file in the root directory of this source tree.
from .distributed import get_world_size_and_rank
from .logging import get_logger, log_once, log_rank_zero
-from .metric_logging import get_metric_logger
__all__ = [
"get_world_size_and_rank",
"get_logger",
"log_once",
"log_rank_zero",
- "get_metric_logger",
]
diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py
new file mode 100644
index 000000000..18a7d65e6
--- /dev/null
+++ b/src/forge/util/_shared_tensor.py
@@ -0,0 +1,440 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import logging
+
+import uuid
+from dataclasses import dataclass
+from multiprocessing import shared_memory
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+@dataclass
+class SharedTensorHandle:
+ shm_name: str
+ shape: Tuple[int, ...]
+ dtype: str
+
+ def to_shared_tensor(self) -> SharedTensor:
+ """
+ Create a SharedTensor from this handle.
+
+ Returns:
+ SharedTensor instance attached to the shared memory referenced by this handle
+ """
+ return SharedTensor(handle=self)
+
+ def drop(self) -> None:
+ """
+ Unlink the shared memory segment.
+
+ This marks the shared memory for deletion. The actual memory will be freed
+ once all processes have closed their handles to it.
+
+ Note: This only unlinks, it does not close any handles. Processes that have
+ opened this shared memory should call close() on their SharedTensor instances.
+ """
+ try:
+ # Attach to the shared memory just to unlink it
+ shm = shared_memory.SharedMemory(name=self.shm_name)
+ shm.close()
+ shm.unlink()
+ except Exception:
+ pass
+
+
+class SharedTensor:
+ """
+ Wrapper class for tensors backed by shared memory.
+
+ This class provides a way to share tensors between processes using POSIX shared memory.
+ It's designed for efficient inter-process tensor communication without copying data.
+
+ Ownership and Lifecycle Model:
+ ------------------------------
+ 1. **Creator process**:
+ - Creates SharedTensor with tensor data or empty
+ - Gets a handle via get_handle() to pass to other processes
+ - **MUST** call close() after getting handle to release its reference
+ - **SHOULD** call drop()/unlink() when all processes are done
+
+ 2. **Receiver processes**:
+ - Receive SharedTensorHandle (via RPC, pickle, etc.)
+ - Create SharedTensor from handle: SharedTensor(handle=handle)
+ - Use the tensor: handle.to_shared_tensor().tensor
+ - **MUST** call close() when done using the tensor
+
+ 3. **Cleanup**:
+ - close(): Closes this process's file descriptor/handle
+ - drop()/unlink(): Marks shared memory for deletion (call once, from any process)
+ - Actual memory is freed when all processes have closed AND unlink is called
+
+ Memory Leak Prevention:
+ ----------------------
+ - **DO NOT** rely on __del__ for cleanup! Python GC is unpredictable.
+ - **ALWAYS** explicitly call close() when done with a SharedTensor
+ - **ALWAYS** call drop() on handles when sharing is complete
+ - Use context manager (with statement) for automatic cleanup
+ - After close(), accessing .tensor will raise RuntimeError
+ - After close(), getting handle will raise RuntimeError
+
+ Closed State Behavior:
+ ---------------------
+ - Once close() is called, the SharedTensor enters a closed state
+ - Accessing .tensor after close() raises RuntimeError
+ - Calling get_handle() after close() raises RuntimeError
+ - You can check the state with the .is_closed property
+ - close() and drop() are idempotent (safe to call multiple times)
+
+ Important Warning:
+ ------------------
+ If you hold a reference to the tensor BEFORE calling close(), that
+ reference becomes INVALID after close():
+ t = shared.tensor # Get reference
+ shared.close() # Close SharedTensor - unmaps memory
+ t.sum() # SEGFAULT! The memory is now invalid
+
+ After close(), the shared memory mapping is unmapped, so ALL references
+ to the tensor (including cached ones) point to invalid memory. Accessing
+ them will cause segmentation faults or undefined behavior.
+
+ Always ensure you're done with the tensor before calling close().
+
+ Example Usage:
+ -------------
+ # Creator process
+ tensor = torch.randn(100, 100)
+ shared = SharedTensor(tensor=tensor)
+ handle = shared.get_handle()
+ shared.close() # Close creator's reference
+ # ... send handle to other process via RPC ...
+ handle.drop() # Unlink after all receivers have it
+
+ # Receiver process
+ # ... receive handle via RPC ...
+ shared = SharedTensor(handle=handle)
+ result = shared.tensor.sum() # Use the tensor
+ shared.close() # Close receiver's reference
+
+ # Or use context manager (recommended)
+ with SharedTensor(handle=handle) as shared:
+ result = shared.tensor.sum()
+ # Automatically closed
+ """
+
+ def __init__(
+ self,
+ *,
+ tensor: Optional[torch.Tensor] = None,
+ handle: Optional[SharedTensorHandle] = None,
+ ):
+ if tensor is not None:
+ self._create_from_tensor(tensor)
+ elif handle is not None:
+ self._create_from_handle(handle)
+ else:
+ raise ValueError("Must provide either tensor or handle")
+
+ @classmethod
+ def empty(
+ cls,
+ shape: Union[Tuple[int, ...], torch.Size],
+ dtype: torch.dtype = torch.float32,
+ ):
+ """
+ Create an empty tensor directly in shared memory (no copy/allocation overhead)
+
+ Args:
+ shape: Shape of the tensor
+ dtype: PyTorch dtype (supports bfloat16, float32, etc.)
+
+ Returns:
+ SharedTensor instance with uninitialized data
+ """
+ instance = cls.__new__(cls)
+ instance._create_empty(shape, dtype)
+ return instance
+
+ @classmethod
+ def zeros(
+ cls,
+ shape: Union[Tuple[int, ...], torch.Size],
+ dtype: torch.dtype = torch.float32,
+ ):
+ """
+ Create a zero-initialized tensor in shared memory
+
+ Args:
+ shape: Shape of the tensor
+ dtype: PyTorch dtype
+
+ Returns:
+ SharedTensor instance with zeros
+ """
+ shared_tensor = cls.empty(shape, dtype)
+ shared_tensor.tensor.zero_()
+ return shared_tensor
+
+ @classmethod
+ def ones(
+ cls,
+ shape: Union[Tuple[int, ...], torch.Size],
+ dtype: torch.dtype = torch.float32,
+ ):
+ """
+ Create a ones-initialized tensor in shared memory
+
+ Args:
+ shape: Shape of the tensor
+ dtype: PyTorch dtype
+
+ Returns:
+ SharedTensor instance with ones
+ """
+ shared_tensor = cls.empty(shape, dtype)
+ shared_tensor.tensor.fill_(1)
+ return shared_tensor
+
+ def _create_empty(self, shape, dtype):
+ """Initialize with empty tensor in shared memory"""
+ # Initialize lifecycle state
+ self._closed = False
+ self._tensor_cache = None
+
+ # Store metadata
+ self._shape = tuple(shape) if not isinstance(shape, tuple) else shape
+ self._dtype = dtype
+ self._dtype_str = str(dtype)
+
+ # Calculate size
+ element_size = torch.tensor([], dtype=dtype).element_size()
+ total_elements = int(np.prod(self._shape))
+ byte_size = total_elements * element_size
+
+ # Create shared memory (uninitialized - fast!)
+ shm_name = f"shared_tensor_{uuid.uuid4().hex}"
+ self._shm = shared_memory.SharedMemory(
+ create=True, size=byte_size, name=shm_name
+ )
+ self._shm_name = shm_name
+
+ def _create_from_tensor(self, tensor):
+ """Initialize from an existing tensor"""
+ # Initialize lifecycle state
+ self._closed = False
+ self._tensor_cache = None
+
+ tensor = tensor.contiguous()
+
+ # Store metadata
+ self._shape = tuple(tensor.shape)
+ self._dtype = tensor.dtype
+ self._dtype_str = str(tensor.dtype)
+
+ # Create shared memory
+ byte_size = tensor.numel() * tensor.element_size()
+ shm_name = f"shared_tensor_{uuid.uuid4().hex}"
+
+ self._shm = shared_memory.SharedMemory(
+ create=True, size=byte_size, name=shm_name
+ )
+ self._shm_name = shm_name
+
+ # Copy data as raw bytes
+ raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy()
+ self._shm.buf[:byte_size] = raw_bytes
+ del raw_bytes # Explicitly free the intermediate numpy array
+
+ def _create_from_handle(self, handle: SharedTensorHandle):
+ """Initialize from a handle"""
+ # Initialize lifecycle state
+ self._closed = False
+ self._tensor_cache = None
+
+ self._shm_name = handle.shm_name
+ self._shape = handle.shape
+ self._dtype_str = handle.dtype
+ self._dtype = self._parse_dtype(self._dtype_str)
+
+ # Attach to existing shared memory\
+ self._shm = shared_memory.SharedMemory(name=self._shm_name)
+
+ def _create_tensor_view(self):
+ """Create tensor view of shared memory."""
+ element_size = torch.tensor([], dtype=self._dtype).element_size()
+ total_elements = int(np.prod(self._shape))
+ byte_size = total_elements * element_size
+
+ # Create numpy array that shares the buffer
+ np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self._shm.buf)
+ # Create torch tensor from numpy (shares memory)
+ uint8_tensor = torch.from_numpy(np_array)
+ tensor = uint8_tensor.view(self._dtype).reshape(self._shape)
+
+ # Keep the np array alive
+ tensor._forge_np_array = np_array
+
+ return tensor
+
+ def _parse_dtype(self, dtype_str):
+ """Parse dtype string"""
+ dtype_str = dtype_str.replace("torch.", "")
+ return getattr(torch, dtype_str)
+
+ def get_handle(self):
+ """
+ Get a picklable handle to share this SharedTensor with other processes.
+
+ Returns:
+ SharedTensorHandle: A lightweight handle that can be pickled and sent to other processes
+
+ Raises:
+ RuntimeError: If called after close() has been called
+ """
+ if self._closed:
+ raise RuntimeError(
+ "Cannot get handle after close(). Get the handle before closing."
+ )
+ return SharedTensorHandle(
+ shm_name=self._shm_name,
+ shape=self._shape,
+ dtype=self._dtype_str,
+ )
+
+ @property
+ def tensor(self):
+ """
+ Get the underlying tensor.
+
+ Returns:
+ torch.Tensor: View into the shared memory
+
+ Raises:
+ RuntimeError: If accessed after close() has been called
+ """
+ if self._closed:
+ raise RuntimeError(
+ "Cannot access tensor after close(). The SharedTensor has been closed."
+ )
+ if self._tensor_cache is None:
+ self._tensor_cache = self._create_tensor_view()
+ return self._tensor_cache
+
+ def copy_from(self, source_tensor):
+ """
+ Copy data from another tensor into this shared tensor
+ Useful when you create empty tensor first, then fill it
+
+ Args:
+ source_tensor: Source tensor to copy from
+ """
+ if source_tensor.shape != self._shape:
+ raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self._shape}")
+ # Copy data
+ self.tensor.copy_(source_tensor)
+
+ def clone(self):
+ """Create a new SharedTensor with copied data"""
+ new_shared = SharedTensor.empty(self._shape, self._dtype)
+ new_shared.tensor.copy_(self.tensor)
+ return new_shared
+
+ def close(self):
+ """
+ Close this process's handle to the shared memory.
+
+ This should be called when this process is done using the shared memory.
+ The shared memory will persist until all processes have closed their handles
+ and someone calls unlink().
+
+ After calling close(), this SharedTensor object should not be used anymore.
+ Accessing the tensor property after close() will raise a RuntimeError.
+
+ This method is idempotent - calling it multiple times is safe.
+
+ Note: If you hold a reference to the tensor before calling close(),
+ that reference will remain valid, but new accesses via shared.tensor
+ will raise an error.
+ """
+ if self._closed:
+ return # Already closed, nothing to do
+
+ self._closed = True
+ self._tensor_cache = None # Release tensor and numpy array references
+
+ try:
+ self._shm.close()
+ except Exception as e:
+ logger.error(f"Error closing shared memory {self._shm_name}: {e}")
+
+ def drop(self):
+ """
+ Close and unlink the shared memory.
+
+ This method first closes this process's handle (if not already closed),
+ then marks the shared memory for deletion. The actual memory will be freed
+ once all processes have closed their handles.
+
+ This method is idempotent - calling it multiple times is safe.
+
+ Note:
+ This should be called when the shared tensor is no longer needed.
+ Failing to call this method may result in shared memory leaks.
+ """
+ # Close first to set _closed flag and release cache
+ self.close()
+
+ # Then unlink
+ try:
+ self._shm.unlink()
+ except Exception as e:
+ raise RuntimeError(
+ f"Error unlinking shared memory {self._shm_name}: {e}"
+ ) from e
+
+ @property
+ def is_closed(self) -> bool:
+ """
+ Check if this SharedTensor has been closed.
+
+ Returns:
+ bool: True if close() has been called, False otherwise
+ """
+ return self._closed
+
+ def __enter__(self):
+ """Context manager entry."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit - closes the shared memory handle."""
+ self.close()
+ return False
+
+ def __del__(self):
+ """
+ Best-effort cleanup on garbage collection.
+
+ WARNING: Do NOT rely on __del__ for cleanup! Python's garbage collector
+ may not call __del__ promptly or at all, which can cause memory leaks.
+ Always explicitly call close() when done with the SharedTensor.
+
+ This __del__ is only a safety net for cases where explicit cleanup is missed.
+ """
+ # Only close if the object was fully initialized
+ if hasattr(self, "_closed"):
+ self.close()
+
+ def __repr__(self):
+ return f"SharedTensor(shape={self._shape}, dtype={self._dtype}, shm_name={self._shm_name})"
diff --git a/src/forge/util/checkpoint.py b/src/forge/util/checkpoint.py
new file mode 100644
index 000000000..771c5876b
--- /dev/null
+++ b/src/forge/util/checkpoint.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+
+import torchstore as ts
+
+from forge.actors._torchstore_utils import (
+ get_dcp_whole_state_dict_key,
+ get_param_prefix,
+)
+
+
+async def drop_weights(version: int):
+ print(f"Dropping weights @ version {version}")
+ start_time = time.perf_counter()
+ prefix = get_param_prefix(version)
+ matching_keys = await ts.keys(prefix)
+ # TODO: once we have something like `get_meta()` in torchstore, we can just
+ # query the type of the object instead of relying on keys.
+ dcp_key = get_dcp_whole_state_dict_key(version)
+ if dcp_key in matching_keys:
+ dcp_handle = await ts.get(dcp_key)
+ dcp_handle.drop()
+ for key in matching_keys:
+ await ts.delete(key)
+ elapsed = time.perf_counter() - start_time
+ print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
diff --git a/src/forge/cli/config.py b/src/forge/util/config.py
similarity index 94%
rename from src/forge/cli/config.py
rename to src/forge/util/config.py
index a5e35cefd..c93c4c575 100644
--- a/src/forge/cli/config.py
+++ b/src/forge/util/config.py
@@ -15,6 +15,12 @@
from omegaconf import DictConfig, OmegaConf
+# Add support for summing lists of numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}}
+OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)
+
+# Add support for boolean negation, e.g. ${not:${compile}}
+OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)
+
def _has_component(node: Any) -> bool:
"""Check if a node has a _component_ field."""
@@ -56,22 +62,22 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: list[str]) -> DictC
cli args, respectively) and merges them into a single OmegaConf DictConfig.
If a cli arg overrides a yaml arg with a _component_ field, the cli arg can
- be specified with the parent field directly, e.g., model=torchtune.models.lora_llama2_7b
- instead of model._component_=torchtune.models.lora_llama2_7b. Nested fields within the
+ be specified with the parent field directly, e.g., model=my_module.models.my_model
+ instead of model._component_=my_module.models.my_model. Nested fields within the
component should be specified with dot notation, e.g., model.lora_rank=16.
Example:
>>> config.yaml:
>>> a: 1
>>> b:
- >>> _component_: torchtune.models.my_model
+ >>> _component_: my_module.models.my_model
>>> c: 3
- >>> tune full_finetune --config config.yaml b=torchtune.models.other_model b.c=4
+ >>> python main.py --config config.yaml b=my_module.models.other_model b.c=4
>>> yaml_args, cli_args = parser.parse_known_args()
>>> conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
>>> print(conf)
- >>> {"a": 1, "b": {"_component_": "torchtune.models.other_model", "c": 4}}
+ >>> {"a": 1, "b": {"_component_": "my_module.models.other_model", "c": 4}}
Args:
yaml_args (Namespace): Namespace containing args from yaml file, components
diff --git a/src/forge/util/logging.py b/src/forge/util/logging.py
index e53218ccd..8a7c1c99d 100644
--- a/src/forge/util/logging.py
+++ b/src/forge/util/logging.py
@@ -4,33 +4,52 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# FIXME: remove this once wandb fixed this issue
+# https://github.com/wandb/wandb/issues/10890
+# Patch importlib.metadata.distributions before wandb imports it
+# to filter out packages with None metadata
+import importlib.metadata
+
+# Guard to ensure this runs only once
+if not hasattr(importlib.metadata, "_distributions_patched"):
+ _original_distributions = importlib.metadata.distributions
+
+ def _patched_distributions():
+ """Filter out distributions with None metadata"""
+ for distribution in _original_distributions():
+ if distribution.metadata is not None:
+ yield distribution
+
+ importlib.metadata.distributions = _patched_distributions
+ importlib.metadata._distributions_patched = True
+
import logging
from functools import lru_cache
-from typing import Optional, TypeVar
from torch import distributed as dist
-T = TypeVar("T", bound=type)
-
-def get_logger(level: Optional[str] = None) -> logging.Logger:
+def get_logger(level: str | None = None) -> logging.Logger:
"""
Get a logger with a stream handler.
Args:
- level (Optional[str]): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels.
+ level (str | None): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels.
Example:
>>> logger = get_logger("INFO")
>>> logger.info("Hello world!")
- INFO:torchtune.utils._logging:Hello world!
+ INFO:forge.util.logging: Hello world!
Returns:
logging.Logger: The logger.
"""
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
- logger.addHandler(logging.StreamHandler())
+ handler = logging.StreamHandler()
+ formatter = logging.Formatter("%(levelname)s:%(name)s: %(message)s")
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
if level is not None:
level = getattr(logging, level.upper())
logger.setLevel(level)
diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py
deleted file mode 100644
index 1141cfbd7..000000000
--- a/src/forge/util/metric_logging.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-import os
-import sys
-import time
-from typing import Mapping, Optional, Union
-
-from forge.interfaces import MetricLogger
-from forge.types import Scalar
-from forge.util.distributed import get_world_size_and_rank
-
-
-def get_metric_logger(logger: str = "stdout", **log_config):
- return METRIC_LOGGER_STR_TO_CLS[logger](**log_config)
-
-
-class StdoutLogger(MetricLogger):
- """Logger to standard output.
-
- Args:
- freq (Union[int, Mapping[str, int]]):
- If int, all metrics will be logged at this frequency.
- If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
- """
-
- def __init__(self, freq: Union[int, Mapping[str, int]]):
- self._freq = freq
-
- def is_log_step(self, name: str, step: int) -> bool:
- """Returns true if the current step is a logging step.
-
- Args:
- name (str): metric name (for checking the freq for this metric)
- step (int): current step
- """
- if isinstance(self._freq, int):
- return step % self._freq == 0
- return step % self._freq[name] == 0
-
- def log(self, name: str, data: Scalar, step: int) -> None:
- """Log the metric if it is a logging step.
-
- Args:
- name (str): metric name
- data (Scalar): metric value
- step (int): current step
- """
- if not self.is_log_step(name, step):
- return
- print(f"Step {step} | {name}:{data}")
-
- def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
- """Log the metrics for which this is currently a logging step.
-
- Args:
- metrics (Mapping[str, Scalar]): dict of metric names and values
- step (int): current step
- """
- log_step_metrics = {
- name: value
- for name, value in metrics.items()
- if self.is_log_step(name, step)
- }
- if not log_step_metrics:
- return
-
- print(f"Step {step} | ", end="")
- for name, data in log_step_metrics.items():
- print(f"{name}:{data} ", end="")
- print("\n", end="")
-
- def close(self) -> None:
- sys.stdout.flush()
-
-
-class TensorBoardLogger(MetricLogger):
- """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).
-
- Args:
- freq (Union[int, Mapping[str, int]]):
- If int, all metrics will be logged at this frequency.
- If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
- log_dir (str): torch.TensorBoard log directory
- organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current
- run. Having sub-directories allows you to compare logs across runs. When TensorBoard is
- passed a logdir at startup, it recursively walks the directory tree rooted at logdir looking for
- subdirectories that contain tfevents data. Every time it encounters such a subdirectory,
- it loads it as a new run, and the frontend will organize the data accordingly.
- Recommended value is `True`. Run `tensorboard --logdir my_log_dir` to view the logs.
- **kwargs: additional arguments
-
- Example:
- >>> from forge.util.metric_logging import TensorBoardLogger
- >>> logger = TensorBoardLogger(freq={"loss": 10}, log_dir="my_log_dir")
- >>> logger.log("my_metric", 1.0, 1)
- >>> logger.log_dict({"my_metric": 1.0}, 1)
- >>> logger.close()
-
- Note:
- This utility requires the tensorboard package to be installed.
- You can install it with `pip install tensorboard`.
- In order to view TensorBoard logs, you need to run `tensorboard --logdir my_log_dir` in your terminal.
- """
-
- def __init__(
- self,
- freq: Union[int, Mapping[str, int]],
- log_dir: str = "metrics_log",
- organize_logs: bool = True,
- **kwargs,
- ):
- from torch.utils.tensorboard import SummaryWriter
-
- self._freq = freq
- self._writer: Optional[SummaryWriter] = None
- _, rank = get_world_size_and_rank()
-
- # In case organize_logs is `True`, update log_dir to include a subdirectory for the
- # current run
- self.log_dir = (
- os.path.join(log_dir, f"run_{rank}_{time.time()}")
- if organize_logs
- else log_dir
- )
-
- # Initialize the log writer only if we're on rank 0.
- if rank == 0:
- self._writer = SummaryWriter(log_dir=self.log_dir)
-
- def is_log_step(self, name: str, step: int) -> bool:
- """Returns true if the current step is a logging step.
-
- Args:
- name (str): metric name (for checking the freq for this metric)
- step (int): current step
- """
- if isinstance(self._freq, int):
- return step % self._freq == 0
- return step % self._freq[name] == 0
-
- def log(self, name: str, data: Scalar, step: int) -> None:
- """Log the metric if it is a logging step.
-
- Args:
- name (str): metric name
- data (Scalar): metric value
- step (int): current step
- """
- if self._writer:
- self._writer.add_scalar(name, data, global_step=step, new_style=True)
-
- def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
- """Log the metrics for which this is currently a logging step.
-
- Args:
- metrics (Mapping[str, Scalar]): dict of metric names and values
- step (int): current step
- """
- for name, data in metrics.items():
- if self.is_log_step(name, step):
- self.log(name, data, step)
-
- def close(self) -> None:
- if self._writer:
- self._writer.close()
- self._writer = None
-
-
-class WandBLogger(MetricLogger):
- """Logger for use w/ Weights and Biases application (https://wandb.ai/).
- For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init.
-
- Args:
- freq (Union[int, Mapping[str, int]]):
- If int, all metrics will be logged at this frequency.
- If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
- log_dir (Optional[str]): WandB log directory.
- project (str): WandB project name. Default is `torchtune`.
- entity (Optional[str]): WandB entity name. If you don't specify an entity,
- the run will be sent to your default entity, which is usually your username.
- group (Optional[str]): WandB group name for grouping runs together. If you don't
- specify a group, the run will be logged as an individual experiment.
- **kwargs: additional arguments to pass to wandb.init
-
- Example:
- >>> from forge.util.metric_logging import WandBLogger
- >>> logger = WandBLogger(freq={"loss": 10}, log_dir="wandb", project="my_project")
- >>> logger.log("my_metric", 1.0, 1)
- >>> logger.log_dict({"my_metric": 1.0}, 1)
- >>> logger.close()
-
- Raises:
- ImportError: If ``wandb`` package is not installed.
-
- Note:
- This logger requires the wandb package to be installed.
- You can install it with `pip install wandb`.
- In order to use the logger, you need to login to your WandB account.
- You can do this by running `wandb login` in your terminal.
- """
-
- def __init__(
- self,
- freq: Union[int, Mapping[str, int]],
- project: str,
- log_dir: str = "metrics_log",
- entity: Optional[str] = None,
- group: Optional[str] = None,
- **kwargs,
- ):
- self._freq = freq
-
- try:
- import wandb
- except ImportError as e:
- raise ImportError(
- "``wandb`` package not found. Please install wandb using `pip install wandb` to use WandBLogger."
- ) from e
- self._wandb = wandb
-
- if not os.path.exists(log_dir):
- os.makedirs(log_dir)
-
- _, rank = get_world_size_and_rank()
- if self._wandb.run is None and rank == 0:
- # we check if wandb.init got called externally
- run = self._wandb.init(
- project=project,
- entity=entity,
- group=group,
- dir=log_dir,
- **kwargs,
- )
-
- if self._wandb.run:
- # define default x-axis (for latest wandb versions)
- if getattr(self._wandb, "define_metric", None):
- self._wandb.define_metric("step")
- self._wandb.define_metric("*", step_metric="step", step_sync=True)
-
- def is_log_step(self, name: str, step: int) -> bool:
- """Returns true if the current step is a logging step.
-
- Args:
- name (str): metric name (for checking the freq for this metric)
- step (int): current step
- """
- if isinstance(self._freq, int):
- return step % self._freq == 0
- return step % self._freq[name] == 0
-
- def log(self, name: str, data: Scalar, step: int) -> None:
- """Log the metric if it is a logging step.
-
- Args:
- name (str): metric name
- data (Scalar): metric value
- step (int): current step
- """
- if self._wandb.run and self.is_log_step(name, step):
- self._wandb.log({name: data, "step": step})
-
- def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
- """Log the metrics for which this is currently a logging step.
-
- Args:
- metrics (Mapping[str, Scalar]): dict of metric names and values
- step (int): current step
- """
- log_step_metrics = {
- name: value
- for name, value in metrics.items()
- if self.is_log_step(name, step)
- }
- if not log_step_metrics:
- return
-
- if self._wandb.run:
- self._wandb.log({**metrics, "step": step})
-
- def close(self) -> None:
- if hasattr(self, "_wandb") and self._wandb.run:
- self._wandb.finish()
-
-
-# TODO: replace with direct instantiation via a path to the class in the config
-METRIC_LOGGER_STR_TO_CLS = {
- "stdout": StdoutLogger,
- "tensorboard": TensorBoardLogger,
- "wandb": WandBLogger,
-}
diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py
index 2eca1fdd1..59bff8570 100644
--- a/src/forge/util/ops.py
+++ b/src/forge/util/ops.py
@@ -6,69 +6,105 @@
import torch
import torch.nn.functional as F
+from torch.distributed.tensor import DTensor
-def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+def compute_logprobs(
+ logits: torch.Tensor | DTensor,
+ input_ids: torch.Tensor,
+ temperature: float = 1.0,
+ align: bool = True,
+) -> torch.Tensor:
"""
- A memory-efficient implementation of the common `log_softmax -> gather` operation.
+ Computes the log probabilities of the input tokens given the model logits and temperature.
+ Always converts inputs to fp32 for numerical stability.
- This function is equivalent to the following naive implementation:
- ```python
- logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
- ```
+ This function handles two common usage patterns:
- Args:
- logits (`torch.Tensor`):
- Logits tensor of shape `(..., num_classes)`.
- index (`torch.Tensor`):
- Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
+ **Pattern 1: Pre-aligned logits (align=False)**
+ Use when logits are already aligned with input_ids, typically when you:
+ - Pass input_ids to the model: model(input_ids) -> logits
+ - The model outputs logits[i] that predict target_ids[i]
+ - logits.shape[1] == input_ids.shape[1]
- Returns:
- `torch.Tensor`:
- Gathered log probabilities with the same shape as `index`.
- """
- if logits.dtype in [torch.float32, torch.float64]:
- selected_logits = torch.gather(
- logits, dim=-1, index=index.unsqueeze(-1)
- ).squeeze(-1)
- # loop to reduce peak mem consumption
- logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
- per_token_logps = (
- selected_logits - logsumexp_values
- ) # log_softmax(x_i) = x_i - logsumexp(x)
- else:
- # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
- per_token_logps = []
- for row_logits, row_labels in zip(
- logits, index
- ): # loop to reduce peak mem consumption
- row_logps = F.log_softmax(row_logits, dim=-1)
- row_per_token_logps = row_logps.gather(
- dim=-1, index=row_labels.unsqueeze(-1)
- ).squeeze(-1)
- per_token_logps.append(row_per_token_logps)
- per_token_logps = torch.stack(per_token_logps)
- return per_token_logps
+ Example:
+ >>> input_ids = torch.tensor([[1, 2, 3, 4]]) # Model input
+ >>> target_ids = torch.tensor([[2, 3, 4, 5]]) # Shifted by 1 (next-token prediction)
+ >>> logits = model(input_ids) # Shape: [1, 4, vocab_size]
+ >>> # logits already aligned: logits[:, i] predicts target_ids[:, i]
+ >>> logprobs = compute_logprobs(logits, target_ids, align=False)
+ **Pattern 2: Full-sequence logits needing alignment (align=True, default)**
+ Use when you have logits for the full sequence but only want log probs for a subset
+ (e.g., just the response tokens, not the prompt). The function will:
+ - Slice logits to match the length of input_ids
+ - Take logits[:, -len(input_ids)-1:-1] to get positions that predict input_ids
-def compute_logprobs(
- logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
- """
- Computes the log probabilities of the input tokens given the model logits and temperature.
+ Example:
+ >>> # Full sequence passed to model: [prompt + response]
+ >>> full_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]]) # Prompt + response
+ >>> logits = model(full_input_ids) # Shape: [1, 6, vocab_size]
+ >>> # Only want log probs for response tokens
+ >>> response_tokens = torch.tensor([[4, 5, 6]]) # Just the response
+ >>> logprobs = compute_logprobs(logits, response_tokens, align=True)
+ >>> # Function slices logits[:, -4:-1] to get logits that predict tokens [4, 5, 6]
+
+ The alignment logic ensures that when you have a full sequence but only want log
+ probabilities for the response portion, you don't need to re-run the model. This
+ is a key optimization in RL training where the prompt remains constant.
+
+ **Tensor Parallelism Support:**
+ When logits is a DTensor sharded on the vocab dimension (e.g., from tensor parallel
+ training), wrap calls to this function with `loss_parallel()` context:
+
+ >>> from torch.distributed.tensor.parallel import loss_parallel
+ >>> with loss_parallel():
+ ... logprobs = compute_logprobs(logits, input_ids)
+
+ The `loss_parallel` context ensures F.cross_entropy works correctly with
+ vocab-sharded DTensors without needing to gather the full tensor.
Args:
logits (`torch.Tensor`):
The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
+ Can be a regular Tensor or a DTensor (when using with loss_parallel context).
input_ids (`torch.Tensor`):
- The input token ids of shape `(batch_size, target_sequence_length)`.
+ The target token ids of shape `(batch_size, target_sequence_length)`.
+ These are the tokens for which you want to compute log probabilities.
temperature (`float`, *optional*, defaults to 1.0):
The temperature value for scaling logits before computing log probabilities.
+ Higher values make the distribution more uniform, lower values more peaked.
+ align (`bool`, *optional*, defaults to True):
+ If True (default), align logits with input_ids by slicing to extract the
+ relevant positions from a longer sequence (Pattern 2).
+ If False, assume logits are already aligned with input_ids (Pattern 1).
+
+ Returns:
+ torch.Tensor: Log probabilities of shape `(batch_size, target_sequence_length)`.
+ Each element [b, i] is the log probability of input_ids[b, i] given the
+ corresponding logits.
+ Note:
+ This function uses cross_entropy instead of log_softmax + gather for better
+ numerical stability, especially important for fp16/bf16 training.
"""
- # Ignore the last token from logits because it predicts the next token (-1)
- # And align logits with the input tokens length.
- logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+ # Align logits with input_ids if requested
+ if align:
+ # Ignore the last token from logits because it predicts the next token (-1)
+ # And align logits with the input tokens length.
+ logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
+
scaled_logits = logits / temperature
- logprobs = selective_log_softmax(scaled_logits, input_ids)
- return logprobs
+
+ # Cast up to fp32 for numerical stability
+ scaled_logits_fp32 = scaled_logits.float()
+
+ # get per-token log probs
+ batch_size, seq_len, vocab_size = scaled_logits_fp32.shape
+ logprobs = -F.cross_entropy(
+ scaled_logits_fp32.reshape(-1, vocab_size),
+ input_ids.reshape(-1).long(),
+ reduction="none",
+ )
+
+ return logprobs.reshape(batch_size, seq_len)
diff --git a/src/forge/util/weight_verification.py b/src/forge/util/weight_verification.py
new file mode 100644
index 000000000..aa98d7df1
--- /dev/null
+++ b/src/forge/util/weight_verification.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Utilities for verifying model weight updates during training."""
+
+import logging
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class WeightSnapshot:
+ """Snapshot of model weights at a specific point in time."""
+
+ params: dict[str, torch.Tensor]
+ version: int | None = None
+ metadata: dict[str, Any] | None = None
+
+ @classmethod
+ def from_model(
+ cls, model: nn.Module, version: int | None = None, device: str = "cpu"
+ ) -> "WeightSnapshot":
+ """Create a snapshot of model parameters.
+
+ Args:
+ model: PyTorch model to snapshot
+ version: Optional version identifier
+ device: Device to store snapshot tensors (default: cpu)
+
+ Returns:
+ WeightSnapshot containing detached copies of all parameters
+ """
+ params = {}
+ for name, param in model.named_parameters():
+ params[name] = param.detach().to(device).clone()
+
+ return cls(params=params, version=version)
+
+
+@dataclass
+class WeightVerificationResult:
+ """Result of weight verification check."""
+
+ weights_changed: bool
+ num_params_checked: int
+ num_params_changed: int
+ num_params_unchanged: int
+ num_params_skipped: int
+ changed_params: list[str]
+ unchanged_params: list[str]
+ skipped_params: list[str]
+ max_delta: float | None = None
+ mean_delta: float | None = None
+
+ def __str__(self) -> str:
+ status = "✅ CHANGED" if self.weights_changed else "⚠️ UNCHANGED"
+ max_delta = f"{self.max_delta:.6e}" if self.max_delta is not None else "N/A"
+ mean_delta = f"{self.mean_delta:.6e}" if self.mean_delta is not None else "N/A"
+
+ return (
+ f"Weight Verification {status}:\n"
+ f" Checked: {self.num_params_checked}\n"
+ f" Changed: {self.num_params_changed}\n"
+ f" Unchanged: {self.num_params_unchanged}\n"
+ f" Skipped: {self.num_params_skipped}\n"
+ f" Max delta: {max_delta}\n"
+ f" Mean delta: {mean_delta}"
+ )
+
+
+def verify_weights_changed(
+ prev_snapshot: WeightSnapshot,
+ current_model: nn.Module,
+ atol: float = 1e-6,
+ rtol: float = 1e-5,
+ skip_non_float: bool = True,
+ verbose: bool = False,
+) -> WeightVerificationResult:
+ """Verify that model weights have changed compared to a previous snapshot.
+
+ This is a more robust verification than simple parameter hashing, as it:
+ - Checks each parameter individually
+ - Uses proper floating point comparison (torch.allclose)
+ - Provides detailed information about which parameters changed
+ - Computes statistics about the magnitude of changes
+
+ Args:
+ prev_snapshot: Previous weight snapshot to compare against
+ current_model: Current model to check
+ atol: Absolute tolerance for considering weights unchanged
+ rtol: Relative tolerance for considering weights unchanged
+ skip_non_float: Whether to skip non-floating point parameters
+ verbose: Whether to log detailed information
+
+ Returns:
+ WeightVerificationResult with detailed information about changes
+ """
+ changed_params = []
+ unchanged_params = []
+ skipped_params = []
+ deltas = []
+
+ for name, param in current_model.named_parameters():
+ if skip_non_float and not torch.is_floating_point(param):
+ skipped_params.append(name)
+ if verbose:
+ logger.info(f"Skipping non-float param: {name}")
+ continue
+
+ if name not in prev_snapshot.params:
+ logger.warning(f"Parameter {name} not found in previous snapshot")
+ skipped_params.append(name)
+ continue
+
+ prev_param = prev_snapshot.params[name]
+ curr_param = param.detach().cpu()
+
+ # Check if parameters are close (i.e., unchanged)
+ is_close = torch.allclose(prev_param, curr_param, atol=atol, rtol=rtol)
+
+ if is_close:
+ unchanged_params.append(name)
+ else:
+ changed_params.append(name)
+ # Compute delta for statistics
+ delta = (curr_param - prev_param).abs().max().item()
+ deltas.append(delta)
+
+ if verbose:
+ logger.info(
+ f"Parameter {name} changed - max delta: {delta:.6e}, "
+ f"mean delta: {(curr_param - prev_param).abs().mean().item():.6e}"
+ )
+
+ # Compute statistics
+ max_delta = max(deltas) if deltas else 0
+ mean_delta = sum(deltas) / len(deltas) if deltas else 0
+
+ result = WeightVerificationResult(
+ weights_changed=len(changed_params) > 0,
+ num_params_checked=len(changed_params) + len(unchanged_params),
+ num_params_changed=len(changed_params),
+ num_params_unchanged=len(unchanged_params),
+ num_params_skipped=len(skipped_params),
+ changed_params=changed_params,
+ unchanged_params=unchanged_params,
+ skipped_params=skipped_params,
+ max_delta=max_delta,
+ mean_delta=mean_delta,
+ )
+
+ logger.info(str(result))
+
+ return result
+
+
+def verify_weights_all_zeros(
+ current_model: nn.Module,
+ atol: float = 1e-4,
+ rtol: float = 1e-3,
+ skip_non_float: bool = True,
+ verbose: bool = False,
+) -> tuple[bool, list[str], list[str]]:
+ """Verify that all model parameters are zero.
+
+ Args:
+ current_model: Model to check
+ atol: Absolute tolerance
+ rtol: Relative tolerance
+ skip_non_float: Whether to skip non-floating point parameters
+ verbose: Whether to log detailed information
+
+ Returns:
+ Tuple of (all_zeros, zero_params, non_zero_params)
+ """
+ zero_params = []
+ non_zero_params = []
+
+ for name, param in current_model.named_parameters():
+ if skip_non_float and not torch.is_floating_point(param):
+ if verbose:
+ logger.info(f"Skipping non-float param: {name}")
+ continue
+
+ param_cpu = param.detach().cpu()
+ is_zero = torch.allclose(
+ torch.zeros_like(param_cpu), param_cpu, atol=atol, rtol=rtol
+ )
+
+ if is_zero:
+ zero_params.append(name)
+ else:
+ non_zero_params.append(name)
+ if verbose:
+ logger.info(
+ f"Parameter {name} is not zero - "
+ f"max: {param_cpu.abs().max().item():.6e}, "
+ f"mean: {param_cpu.abs().mean().item():.6e}"
+ )
+
+ all_zeros = len(non_zero_params) == 0
+
+ logger.info(
+ f"Zero check: {'✅ PASS' if all_zeros else '⚠️ FAIL'} - "
+ f"{len(zero_params)} zero, {len(non_zero_params)} non-zero"
+ )
+
+ return all_zeros, zero_params, non_zero_params
diff --git a/tests/README.md b/tests/README.md
index d02e49e78..148ab8711 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -5,8 +5,8 @@ This directory contains tests for the forge project, including unit tests and in
## Test Structure
- `unit_tests/`: Contains unit tests for individual components
-- `integration_tests.py`: Contains integration tests that test multiple components together
-- `integration_tests_h100.py`: Contains integration tests specifically designed for H100 GPUs, which utilize symmetric memory and float8.
+- `integration_tests/`: Contains integration tests that test multiple components together
+- `sandbox/`: Contains experimental adhoc scripts used for development and debugging
- `assets/`: Contains test assets and fixtures used by the tests
## Running Tests
@@ -21,50 +21,49 @@ pip install .[dev]
### Running Integration Tests
-To run the integration tests:
+To run all integration tests:
```bash
-python ./tests/integration_tests.py [--config_dir CONFIG_DIR] [--test TEST] [--ngpu NGPU]
+pytest -s tests/integration_tests/
```
-Arguments:
-- `output_dir`: (Required) Directory where test outputs will be stored
-- `--config_dir`: (Optional) Directory containing configuration files (default: "./torchtitan/models/llama3/train_configs")
-- `--test`: (Optional) Specific test to run, use test names from the `build_test_list()` function (default: "all")
-- `--ngpu`: (Optional) Number of GPUs to use for testing (default: 8)
+To run a specific integration test file:
-Examples:
```bash
-# Run all integration tests with 8 GPUs
-python ./tests/integration_tests.py ./test_output
+pytest -s tests/integration_tests/test_vllm_policy_correctness.py
+```
+
+To run a specific integration test function:
+
+```bash
+pytest -s tests/integration_tests/test_vllm_policy_correctness.py::test_same_output
+```
-# Run a specific test with 4 GPUs
-python ./tests/integration_tests.py ./test_output --test default --ngpu 4
+Integration tests support custom options defined in `conftest.py`:
+- `--config`: Path to YAML config file for sanity check tests
+- `--use_dcp`: Override the YAML config `trainer.use_dcp` field (true/false)
-# Run all tests with a custom config directory
-python ./tests/integration_tests.py ./test_output --config_dir ./my_configs
+Example with options:
+```bash
+pytest -s tests/integration_tests/ --config ./path/to/config.yaml --use_dcp true
```
### Running Unit Tests
-To run only the unit tests:
+To run all unit tests:
```bash
pytest -s tests/unit_tests/
```
-### Running Specific Unit Test Files
-
-To run a specific test file:
+To run a specific unit test file:
```bash
-pytest -s tests/unit_tests/test_job_config.py
+pytest -s tests/unit_tests/test_config.py
```
-### Running Specific Test Functions in Unit Tests
-
-To run a specific test function:
+To run a specific unit test function:
```bash
-pytest -s tests/unit_tests/test_job_config.py::TestJobConfig::test_command_line_args
+pytest -s tests/unit_tests/test_config.py::test_cache_hit_scenario
```
diff --git a/tests/conftest.py b/tests/conftest.py
index 3e46704b9..e299915b6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,7 +15,7 @@
import pytest
-from forge.env_constants import FORGE_DISABLE_METRICS
+from forge.env import FORGE_DISABLE_METRICS
@pytest.fixture(autouse=True)
@@ -36,5 +36,5 @@ def test_real_metrics(mock_metrics_globally):
pass
"""
- monkeypatch.setenv(FORGE_DISABLE_METRICS, "true")
+ monkeypatch.setenv(FORGE_DISABLE_METRICS.name, "true")
return Mock()
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
index 4d3a56d04..80e408f03 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -5,16 +5,17 @@ max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM
# Policy configuration
policy:
- engine_config:
+ engine_args:
model: ${model}
tensor_parallel_size: 1
pipeline_parallel_size: 1
- enforce_eager: false
- sampling_config:
+ enforce_eager: ${not:${compile}}
+ sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -34,13 +35,13 @@ trainer:
warmup_steps: 1
training:
local_batch_size: ${batch_size}
- seq_len: 2048
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -63,9 +64,11 @@ trainer:
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
with_gpus: true
+
+actors:
trainer:
procs: 1
num_replicas: 1
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
index 0ac915d2a..d4964baad 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -7,16 +7,17 @@ max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM
# Policy configuration
policy:
- engine_config:
+ engine_args:
model: ${model}
tensor_parallel_size: 4
pipeline_parallel_size: 1
- enforce_eager: false
- sampling_config:
+ enforce_eager: ${not:${compile}}
+ sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
@@ -36,13 +37,13 @@ trainer:
warmup_steps: 1
training:
local_batch_size: ${batch_size}
- seq_len: 2048
+ seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
- enable: false
+ enable: ${compile}
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
@@ -65,9 +66,11 @@ trainer:
# All resource allocations
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
with_gpus: true
+
+actors:
trainer:
procs: 2
num_replicas: 1
diff --git a/tests/integration_tests/test_coder.py b/tests/integration_tests/test_coder.py
new file mode 100644
index 000000000..45a80ec4d
--- /dev/null
+++ b/tests/integration_tests/test_coder.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Integration tests for forge.actors.coder.SandboxedPythonCoder.
+
+Requires enroot to be installed.
+
+"""
+
+import os
+import uuid
+
+import pytest
+
+from forge.actors.coder import SandboxedPythonCoder
+
+
+@pytest.mark.timeout(30)
+@pytest.mark.asyncio
+async def test_coder_runs_python():
+ """Integration test for SandboxedPythonCoder with real container execution."""
+ # Create unique names to avoid test conflicts
+ unique_id = str(uuid.uuid1())
+ container_name = f"test_sandbox_{unique_id}"
+ image_path = f"/tmp/python_test_{unique_id}.sqsh"
+
+ coder = None
+ try:
+ coder = await SandboxedPythonCoder.as_actor(
+ docker_image="docker://python:3.10",
+ sqsh_image_path=image_path,
+ container_name=container_name,
+ )
+
+ # Execute code
+ results, _ = await coder.execute.call_one(
+ code="print('hello world')",
+ )
+ print("Got results", results)
+ assert results == "hello world\n"
+
+ finally:
+ # Clean up resources
+ if coder:
+ await SandboxedPythonCoder.shutdown(coder)
+
+ # Clean up the image file
+ if os.path.exists(image_path):
+ os.unlink(image_path)
+
+
+@pytest.mark.timeout(30)
+@pytest.mark.asyncio
+async def test_coder_catches_error():
+ """Integration test for SandboxedPythonCoder with real container execution."""
+ # Create unique names to avoid test conflicts
+ unique_id = str(uuid.uuid1())
+ container_name = f"test_sandbox_{unique_id}"
+ image_path = f"/tmp/python_test_{unique_id}.sqsh"
+
+ coder = None
+ try:
+ print("starting test")
+ coder = await SandboxedPythonCoder.as_actor(
+ docker_image="docker://python:3.10",
+ sqsh_image_path=image_path,
+ container_name=container_name,
+ )
+ print("Got coder")
+
+ # Execute code
+ _, stderr = await coder.execute.call_one(
+ code="hello world",
+ )
+ print("got stderr", stderr)
+ assert "SyntaxError" in stderr
+
+ finally:
+ # Clean up resources
+ if coder:
+ await SandboxedPythonCoder.shutdown(coder)
+
+ # Clean up the image file
+ if os.path.exists(image_path):
+ os.unlink(image_path)
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index 506fc5553..645718fcf 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -6,22 +6,37 @@
import asyncio
import logging
-from tempfile import TemporaryDirectory
+import shutil
+from pathlib import Path
+import monarch
import pytest
+import pytest_asyncio
import torch
import torchstore as ts
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
-from forge.actors.trainer import RLTrainer
-from forge.cli.config import resolve_hf_hub_paths
+from forge.actors.trainer import TitanTrainer
+from forge.controller.provisioner import init_provisioner
from forge.controller.service.service import uuid
+from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.config import resolve_hf_hub_paths
+from forge.util.weight_verification import (
+ verify_weights_all_zeros,
+ verify_weights_changed,
+ WeightSnapshot,
+)
from monarch.actor import endpoint
from omegaconf import DictConfig, OmegaConf
+# Workaround for monarch mesh shutdown exit code during teardown
+# Without this, proc_mesh.stop will raise exit code 1 after test completes
+monarch.actor.unhandled_fault_hook = lambda failure: None
+
+
requires_cuda = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA not available",
@@ -35,15 +50,17 @@
"""
Run tests:
-pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
- --config tests/integration_tests/artifacts/qwen3_1_7b_tp.yaml --use_dcp=false
+TORCHSTORE_RDMA_ENABLED=0 \
+PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
+ --config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
-pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
- --config apps/grpo/qwen3_8b.yaml
"""
+# Temp directory won't work for multi-node because NFS does not cover the tmp path
+TEST_DCP_DIR = "test_dcp_tmp"
-class MockRLTrainer(RLTrainer):
+
+class MockTitanTrainer(TitanTrainer):
@endpoint
async def zero_out_model_states(self):
"""This simply sets all model weights to zero."""
@@ -52,101 +69,169 @@ async def zero_out_model_states(self):
for k in sd.keys():
if not torch.is_floating_point(sd[k]):
logger.info(
- f"[MockRLTrainer] zero_out_model_states(): skipping non-float param {k}"
+ f"[MockTitanTrainer] zero_out_model_states(): skipping non-float param {k}"
)
continue
sd[k] *= 0.0
-# exceptions sometimes are not propogated in monarch, do it manually
-def validate_fn(prev_params, curr_model, logger) -> Exception | None:
- """Validate that current parameters are the same as prev_params."""
- verified = set()
- skipped = set()
+def _load_config(config_path: str) -> DictConfig:
+ cfg = None
+ try:
+ cfg = OmegaConf.load(config_path)
+ except Exception as e:
+ pytest.fail(f"Failed to load config file {config_path}: {e}")
+
+ assert isinstance(cfg, DictConfig)
+
+ cfg = resolve_hf_hub_paths(cfg)
+ return cfg
+
+
+def _test_validate_params_unchanged(
+ prev_params, curr_model, logger
+) -> Exception | None:
+ """Validate that current parameters are the same as prev_params.
+
+ Uses the new weight_verification utility for robust checking.
+ """
+ prev_snapshot = WeightSnapshot(params=prev_params, version=None)
+ result = verify_weights_changed(
+ prev_snapshot, curr_model, atol=1e-3, rtol=1e-2, verbose=False
+ )
+
logger.info(
- f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
+ f"Validation: {result.num_params_checked} params checked, "
+ f"{result.num_params_changed} changed, {result.num_params_unchanged} unchanged"
)
- errs = []
- for name, param in curr_model.named_parameters():
- if not torch.is_floating_point(param):
- logger.info(f"Skipping non-float param {name}")
- skipped.add(name)
- continue
- try:
- assert name in prev_params, f"Param {name} not found in prev_params"
- assert torch.allclose(
- prev_params[name], param.cpu(), atol=1e-3, rtol=1e-2
- ), (
- f"current param {name} does not match expected value; "
- f"previous param ({prev_params[name].size()})= {prev_params[name]}; "
- f"expected = {prev_params[name]} vs got = {param.cpu().size()} {param.cpu()}"
- )
- verified.add(name)
- except Exception as e:
- # logger.error(f"Validation failed with exception: {e}")
- errs.append((name, e))
- logger.info(f"Verified params = {verified}")
- logger.info(f"Skipped params = {skipped}")
- if errs:
- logger.error(
- f"Validation failed for the following params: {[e[0] for e in errs]}"
+
+ # We EXPECT no changes for this validation
+ if result.weights_changed:
+ error_msg = (
+ f"Weights unexpectedly changed! {result.num_params_changed} params changed "
+ f"(max_delta={result.max_delta:.6e}). Changed params: {result.changed_params[:5]}"
)
- return AssertionError(f"Validation failed: {errs}")
+ logger.error(error_msg)
+ return AssertionError(error_msg)
+
+
+def _test_validate_params_all_zeros(
+ prev_params, curr_model, logger
+) -> Exception | None:
+ """Validate all parameters are set to zero."""
+ _ = prev_params # Unused
+ all_zeros, zero_params, non_zero_params = verify_weights_all_zeros(
+ curr_model, atol=1e-4, rtol=1e-3, verbose=False
+ )
-# exceptions sometimes are not propogated in monarch, do it manually
-def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
- """Validate all parameters are set to zero. prev_params is actually not used."""
- _ = prev_params
- verified = set()
- skipped = set()
logger.info(
- f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
+ f"Zero validation: {len(zero_params)} zero params, {len(non_zero_params)} non-zero params"
)
- errs = []
- for name, param in curr_model.named_parameters():
- if not torch.is_floating_point(param):
- logger.info(f"Skipping non-float param {name}")
- skipped.add(name)
- continue
- try:
- param = param.cpu()
- assert torch.allclose(
- torch.zeros_like(param), param, atol=1e-4, rtol=1e-3
- ), "param {name} is not zero."
- verified.add(name)
- except Exception as e:
- # logger.error(f"Validation failed with exception: {e}")
- errs.append((name, e))
- logger.info(f"Verified params = {verified}")
- logger.info(f"Skipped params = {skipped}")
- if errs:
- logger.error(
- f"Validation failed for the following params: {[e[0] for e in errs]}"
+
+ if not all_zeros:
+ error_msg = (
+ f"Not all params are zero! {len(non_zero_params)} non-zero params found. "
+ f"First few non-zero: {non_zero_params[:5]}"
)
- return AssertionError(f"Validation failed: {errs}")
+ logger.error(error_msg)
+ return AssertionError(error_msg)
+ return None
-class TestWeightSync:
- """Tests for weight sync between trainer and policy."""
- def _load_config(self, config_path: str) -> DictConfig:
- cfg = None
- try:
- cfg = OmegaConf.load(config_path)
- except Exception as e:
- pytest.fail(f"Failed to load config file {config_path}: {e}")
+@pytest_asyncio.fixture(autouse=True)
+async def _setup_and_teardown(request):
+ # ---- setup ---- #
+ config_path = request.config.getoption("--config", default=None)
+ if not config_path:
+ pytest.skip(
+ "No config file provided. Use --config to specify a YAML config file"
+ )
- assert isinstance(cfg, DictConfig)
+ use_dcp_override = request.config.getoption("--use_dcp")
+ cfg = _load_config(config_path=config_path)
- cfg = resolve_hf_hub_paths(cfg)
- return cfg
+ trainer_proc_size = cfg.actors.trainer.procs
+ policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
+
+ if policy_tp_size != cfg.services.policy.procs:
+ pytest.fail(
+ f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
+ )
+
+ model_card = cfg.model
+ logger.info(f"Running sanity check with config: {config_path}")
+ logger.info(f"Model name: {model_card}")
+ logger.info(f"Trainer proc size: {trainer_proc_size}")
+ logger.info(f"Policy tensor parallel size: {policy_tp_size}")
+
+ logger.info("Downloading model checkpoint from HuggingFace Hub")
+ cached_dir = snapshot_download(repo_id=model_card)
+ logger.info("Finished downloading model checkpoint from HuggingFace Hub")
+
+ services_policy_cfg = cfg.services.policy
+ services_policy_cfg.num_replicas = 1
+
+ trainer_cfg = cfg.trainer
+ trainer_cfg.dcp_path = TEST_DCP_DIR
+ trainer_cfg.checkpoint = {
+ "enable": True,
+ "folder": "/tmp/saved_checkpoints",
+ "initial_load_path": cached_dir,
+ "initial_load_in_hf": True,
+ }
+
+ if use_dcp_override is not None:
+ trainer_cfg["use_dcp"] = use_dcp_override
+ logger.info(f"`trainer.use_dcp` is overridden to {use_dcp_override}")
+
+ if cfg.get("provisioner", None) is not None:
+ await init_provisioner(
+ ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+ )
+ await ts.initialize(strategy=ts.ControllerStorageVolumes())
+
+ policy, titan_trainer = await asyncio.gather(
+ *[
+ Generator.options(**services_policy_cfg).as_service(**cfg.policy),
+ MockTitanTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
+ ]
+ )
+
+ yield policy, titan_trainer
+
+ # ---- teardown ---- #
+ logger.info("Shutting down services and cleaning up DCP directory..")
+
+ # Call cleanup to destroy process group before shutdown
+ # This prevents TCPStore connection errors from NCCL heartbeat threads
+ await titan_trainer.cleanup.call()
+
+ # Shutdown sequentially to avoid race conditions
+ await policy.shutdown()
+ await TitanTrainer.shutdown(titan_trainer)
+ await ts.shutdown()
+
+ # Cleanup DCP directory
+ path = Path(TEST_DCP_DIR)
+ if not path.exists() or not path.is_dir():
+ return
+ try:
+ shutil.rmtree(path)
+ logger.info(f"Successfully removed {TEST_DCP_DIR}")
+ except Exception as e:
+ logger.error(f"Failed to remove {TEST_DCP_DIR}: {e}")
+
+
+class TestWeightSync:
+ """Tests for weight sync between trainer and policy."""
@pytest.mark.asyncio
@requires_cuda
- async def test_sanity_check(self, request):
+ async def test_sanity_check(self, _setup_and_teardown):
"""
- Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.
+ Sanity check for weight sync sharding between TitanTrainer and Policy for a given model config.
The check performs the following steps:
- Initialize trainer and push weights v0 (original huggingface ckpt)
@@ -155,89 +240,41 @@ async def test_sanity_check(self, request):
- Load weights v1 and check the policy has all the weights back
"""
- # Test setup
- config_path = request.config.getoption("--config", default=None)
- if not config_path:
- pytest.skip(
- "No config file provided. Use --config to specify a YAML config file"
- )
-
- use_dcp_override = request.config.getoption("--use_dcp")
- cfg = self._load_config(config_path=config_path)
-
- trainer_proc_size = cfg.actors.trainer.procs
- policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
-
- if policy_tp_size != cfg.services.policy.procs:
- pytest.fail(
- f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
- )
-
- model_card = cfg.model
-
- logger.info(f"Running sanity check with config: {config_path}")
- logger.info(f"Model name: {model_card}")
- logger.info(f"Trainer proc size: {trainer_proc_size}")
- logger.info(f"Policy tensor parallel size: {policy_tp_size}")
-
- logger.info("Downloading model checkpoint from HuggingFace Hub")
- cached_dir = snapshot_download(repo_id=model_card)
- logger.info("Finished downloading model checkpoint from HuggingFace Hub")
-
- await ts.initialize()
- services_policy_cfg = cfg.services.policy
- services_policy_cfg.num_replicas = 1
-
- trainer_cfg = cfg.trainer
- trainer_cfg.checkpoint = {
- "enable": True,
- "folder": "/tmp/saved_checkpoints",
- "initial_load_path": cached_dir,
- "initial_load_in_hf": True,
- }
- if use_dcp_override is not None:
- trainer_cfg["use_dcp"] = use_dcp_override
- logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")
-
- with TemporaryDirectory(dir="/dev/shm/") as tmpdir:
- trainer_cfg["dcp_path"] = tmpdir
- policy, rl_trainer = await asyncio.gather(
- *[
- Policy.options(**services_policy_cfg).as_service(**cfg.policy),
- MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
- ]
- )
-
- # Main logic begins here
- v0 = uuid.uuid4().int
- v1 = v0 + 1
-
- await rl_trainer.push_weights.call(policy_version=v0)
- # Setting everything to zero
- await rl_trainer.zero_out_model_states.call()
- await rl_trainer.push_weights.call(policy_version=v1)
- await policy._test_save_model_params.fanout()
-
- # Sanity check that before update all the tests pass
- all_errs = await policy._test_validate_model_params.fanout(validate_fn)
- for errs in all_errs:
- for _, e in errs.items():
- assert not e, f"Validation failed with exception: {e}"
-
- await policy.update_weights.fanout(policy_version=v1)
- all_errs = await policy._test_validate_model_params.fanout(
- validate_fn_all_zeros
- )
- for errs in all_errs:
- for _, e in errs.items():
- assert not e, f"Validation failed with exception: {e}"
-
- # Reloading v0, getting back original weights
- await policy.update_weights.fanout(policy_version=v0)
- all_errs = await policy._test_validate_model_params.fanout(validate_fn)
- for errs in all_errs:
- for _, e in errs.items():
- assert not e, f"Validation failed with exception: {e}"
-
- logger.info("✅ Weight sharding sanity check passed!")
- await ts.shutdown()
+
+ policy, titan_trainer = _setup_and_teardown
+
+ v0 = uuid.uuid4().int
+ v1 = v0 + 1
+
+ await titan_trainer.push_weights.call(policy_version=v0)
+ # Setting everything to zero
+ await titan_trainer.zero_out_model_states.call()
+ await titan_trainer.push_weights.call(policy_version=v1)
+ await policy.save_model_params.fanout()
+
+ # Sanity check that before update all the tests pass
+ all_errs = await policy.validate_model_params.fanout(
+ _test_validate_params_unchanged
+ )
+ for errs in all_errs:
+ for _, e in errs.items():
+ assert not e, f"Validation failed with exception: {e}"
+
+ await policy.update_weights.fanout(version=v1)
+ all_errs = await policy.validate_model_params.fanout(
+ _test_validate_params_all_zeros
+ )
+ for errs in all_errs:
+ for _, e in errs.items():
+ assert not e, f"Validation failed with exception: {e}"
+
+ # Reloading v0, getting back original weights
+ await policy.update_weights.fanout(version=v0)
+ all_errs = await policy.validate_model_params.fanout(
+ _test_validate_params_unchanged
+ )
+ for errs in all_errs:
+ for _, e in errs.items():
+ assert not e, f"Validation failed with exception: {e}"
+
+ logger.info("✅ Weight sharding sanity check passed!")
diff --git a/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py b/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py
index 4fcd850e7..3602b111f 100644
--- a/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py
+++ b/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py
@@ -25,9 +25,9 @@
import torch
from forge.actors.reference_model import ReferenceModel
-from forge.cli.config import _resolve_hf_model_path
from forge.controller import ForgeActor
from forge.controller.provisioner import shutdown
+from forge.util.config import _resolve_hf_model_path
from monarch.actor import endpoint
from torchtitan.config.job_config import Checkpoint, Compile, Model, Parallelism
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -213,7 +213,7 @@ def compare_logits(
hf_val = hf_logits_cpu[pos].item()
diff_val = abs_diff[pos].item()
print(
- f" {i+1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
+ f" {i + 1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
)
return metrics
@@ -242,12 +242,12 @@ def compare_probabilities(
zip(titan_top_k.values, titan_top_k.indices)
):
token = tokenizer.decode([token_id.item()])
- print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
+ print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
print("\nHugging Face Top-K:")
for i, (prob, token_id) in enumerate(zip(hf_top_k.values, hf_top_k.indices)):
token = tokenizer.decode([token_id.item()])
- print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
+ print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
# Calculate overlap in top-k predictions
titan_top_tokens = set(titan_top_k.indices.tolist())
diff --git a/tests/integration_tests/test_vllm_policy_correctness.py b/tests/integration_tests/test_vllm_policy_correctness.py
new file mode 100644
index 000000000..71ff3677b
--- /dev/null
+++ b/tests/integration_tests/test_vllm_policy_correctness.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+from forge.actors.generator import Generator as Policy
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.sampling_params import RequestOutputKind
+from vllm.v1.engine.async_llm import AsyncLLM
+
+
+# Configuration
+MODEL_NAME = "facebook/opt-125m"
+MAX_MODEL_LEN = 512
+GPU_MEMORY_UTILIZATION = 0.1
+ENFORCE_EAGER = True
+ENABLE_PREFIX_CACHING = True
+TENSOR_PARALLEL_SIZE = 1
+
+# Sampling parameters
+MAX_TOKENS = 50
+TEMPERATURE = 0.0 # Deterministic
+TOP_P = 1.0
+N_SAMPLES = 1
+
+
+@pytest.mark.asyncio
+async def test_same_output():
+ """Compare outputs between vLLM and Policy service"""
+ test_prompts = [
+ "Hello, how are you?",
+ "What is 2+2?",
+ "Tell me a joke.",
+ "Explain machine learning briefly.",
+ "What color is the sky?",
+ ]
+ policy = None
+ try:
+ # Setup vLLM directly
+ args = AsyncEngineArgs(
+ model=MODEL_NAME,
+ max_model_len=MAX_MODEL_LEN,
+ gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
+ enforce_eager=ENFORCE_EAGER,
+ enable_prefix_caching=ENABLE_PREFIX_CACHING,
+ )
+ vllm_model = AsyncLLM.from_engine_args(args)
+
+ # Setup Policy service
+ policy = await Policy.options(
+ procs=1, num_replicas=1, with_gpus=True
+ ).as_service(
+ engine_args={
+ "model": MODEL_NAME,
+ "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
+ "enforce_eager": ENFORCE_EAGER,
+ "max_model_len": MAX_MODEL_LEN,
+ "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
+ "enable_prefix_caching": ENABLE_PREFIX_CACHING,
+ },
+ sampling_params={
+ "n": N_SAMPLES,
+ "max_tokens": MAX_TOKENS,
+ "temperature": TEMPERATURE,
+ "top_p": TOP_P,
+ },
+ )
+
+ print("Models ready. Generating outputs...\n")
+ vllm_outputs = []
+ policy_outputs = []
+ sampling_params = SamplingParams(
+ max_tokens=MAX_TOKENS,
+ temperature=TEMPERATURE,
+ top_p=TOP_P,
+ n=N_SAMPLES,
+ output_kind=RequestOutputKind.FINAL_ONLY,
+ )
+
+ for i, prompt in enumerate(test_prompts, 1):
+ # vLLM generation
+ async for res in vllm_model.generate(
+ prompt, sampling_params, request_id=str(i)
+ ):
+ vllm_outputs.append(res.outputs[0].text)
+
+ # Policy generation
+ policy_result = await policy.generate.route(prompt)
+ policy_text = policy_result[0].text
+ policy_outputs.append(policy_text)
+
+ # Final check
+ for vllm_output, policy_output in zip(vllm_outputs, policy_outputs):
+ assert vllm_output != ""
+ assert policy_output != ""
+ assert vllm_output == policy_output
+
+ finally:
+ if policy is not None:
+ await policy.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_cache_usage():
+ """Test that KV cache usage is consistent between vLLM and Policy service.
+
+ Namely we want to check two things:
+ 1. KV cache is populated correctly.
+ 2. KV cache is cleared correctly.
+
+ Our main tool to inspect the KV cache is the `num_cached_tokens` field in the request output.
+ According to the vLLM docs (https://docs.vllm.ai/en/v0.9.0/api/vllm/outputs.html#vllm.outputs.RequestOutput),
+ this is the number of tokens with a prefix cache hit. So, the logic is that if we run one generation,
+ then run another generation with the same start, we should see the number of cached tokens == the length of the prefix.
+
+ Some important caveats:
+ - vLLM does not appear to do partial prefix caching. So if a shared prefix is less than BLOCK_SIZE,
+ it won't be cached.
+ - This is a limited test. Ideally, it would also be good to check the size of the block pool before and after
+ each generation. In addition, it would be interesting to examine the GPU memory freed after
+ calling reset_prefix_cache(); however, it is not exactly clear how to access these internal APIs
+ via the AsyncLLM interface.
+ - We do not test different different block sizes.
+ """
+ policy = None
+ try:
+ # Setup vLLM directly
+ args = AsyncEngineArgs(
+ model=MODEL_NAME,
+ max_model_len=MAX_MODEL_LEN,
+ gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
+ enforce_eager=ENFORCE_EAGER,
+ enable_prefix_caching=ENABLE_PREFIX_CACHING,
+ block_size=16,
+ )
+ vllm_model = AsyncLLM.from_engine_args(args)
+
+ # Setup Policy service
+ policy = await Policy.options(
+ procs=1, num_replicas=1, with_gpus=True
+ ).as_service(
+ engine_args={
+ "model": MODEL_NAME,
+ "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
+ "enforce_eager": ENFORCE_EAGER,
+ "max_model_len": MAX_MODEL_LEN,
+ "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
+ "enable_prefix_caching": ENABLE_PREFIX_CACHING,
+ "block_size": 16,
+ },
+ sampling_params={
+ "n": N_SAMPLES,
+ "max_tokens": MAX_TOKENS,
+ "temperature": TEMPERATURE,
+ "top_p": TOP_P,
+ },
+ )
+
+ print("Models ready. Starting KV cache test...")
+
+ sampling_params = SamplingParams(
+ max_tokens=MAX_TOKENS,
+ temperature=TEMPERATURE,
+ top_p=TOP_P,
+ n=N_SAMPLES,
+ output_kind=RequestOutputKind.FINAL_ONLY,
+ )
+ vllm_outputs = []
+ policy_outputs = []
+
+ # Exactly 16 tokens to fill up 1 block
+ first_prompt = (
+ "The paged prefix caching mechanism in vLLM is an interesting approach."
+ )
+ expected_cached_tokens = 0
+ async for res in vllm_model.generate(
+ first_prompt, sampling_params, request_id="first_16"
+ ):
+ vllm_outputs.append(res.outputs[0].text)
+ assert res.num_cached_tokens == expected_cached_tokens
+ res = await policy.generate.route(first_prompt)
+ assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
+ policy_outputs.append(res[0].text)
+
+ # Another 16 tokens to now populate 2 blocks (+ reuse the first block)
+ second_prompt = (
+ first_prompt
+ + " It removes the need to recalculate attention key-values for already processed text."
+ )
+ expected_cached_tokens = 16
+ async for res in vllm_model.generate(
+ second_prompt, sampling_params, request_id="second_16_use_first_block"
+ ):
+ vllm_outputs.append(res.outputs[0].text)
+ assert res.num_cached_tokens == expected_cached_tokens
+ res = await policy.generate.route(second_prompt)
+ assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
+ policy_outputs.append(res[0].text)
+
+ # The first same 32 tokens should now be populated in blocks
+ third_prompt = second_prompt
+ expected_cached_tokens = 32
+ async for res in vllm_model.generate(
+ third_prompt, sampling_params, request_id="use_both_blocks"
+ ):
+ vllm_outputs.append(res.outputs[0].text)
+ assert res.num_cached_tokens == expected_cached_tokens
+ res = await policy.generate.route(third_prompt)
+ assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
+ policy_outputs.append(res[0].text)
+
+ # Now, let's clear the cache
+ await vllm_model.reset_prefix_cache()
+ await policy._reset_prefix_cache.route()
+
+ # And try the third prompt again (should not use any cached tokens)
+ expected_cached_tokens = 0
+ async for res in vllm_model.generate(
+ third_prompt, sampling_params, request_id="use_no_blocks_bc_cache_cleared"
+ ):
+ vllm_outputs.append(res.outputs[0].text)
+ assert res.num_cached_tokens == expected_cached_tokens
+ res = await policy.generate.route(third_prompt)
+ assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
+ policy_outputs.append(res[0].text)
+
+ # Sanity check that outputs are still the same
+ for vllm_output, policy_output in zip(vllm_outputs, policy_outputs):
+ assert vllm_output != ""
+ assert policy_output != ""
+ assert vllm_output == policy_output
+
+ finally:
+ if policy is not None:
+ await policy.shutdown()
diff --git a/apps/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py
similarity index 83%
rename from apps/rl_trainer/main.py
rename to tests/sandbox/rl_trainer/main.py
index 8473cc16d..8825794b6 100644
--- a/apps/rl_trainer/main.py
+++ b/tests/sandbox/rl_trainer/main.py
@@ -4,14 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-# Usage: python -m apps.rl_trainer.main --config apps/grpo/qwen3_32b.yaml
+# Usage: python -m tests.sandbox.rl_trainer.main --config apps/grpo/qwen3_32b.yaml
import asyncio
import torch
import torchstore as ts
-from forge.actors.trainer import RLTrainer
-from forge.cli.config import parse
+from forge.actors.trainer import TitanTrainer
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
@@ -23,6 +22,7 @@
ProvisionerConfig,
ServiceConfig,
)
+from forge.util.config import parse
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -40,17 +40,17 @@ def simple_grpo_loss(
Just performs basic tensor operations to simulate memory usage.
"""
# Extract dimensions
- batch_size, response_len = response.shape
+ local_batch_size, response_len = response.shape
vocab_size = logits.size(-1)
full_seq_len = logits.size(1)
# Extract only the response portion from logits
- # logits shape: [batch_size, request_len + response_len, vocab_size]
+ # logits shape: [local_batch_size, request_len + response_len, vocab_size]
# We want the last response_len tokens
request_len = full_seq_len - response_len
response_logits = logits[
:, request_len:, :
- ] # [batch_size, response_len, vocab_size]
+ ] # [local_batch_size, response_len, vocab_size]
# Flatten logits and response for cross-entropy
logits_flat = response_logits.reshape(-1, vocab_size)
@@ -59,7 +59,7 @@ def simple_grpo_loss(
# Basic cross-entropy loss (simplified)
loss = torch.nn.functional.cross_entropy(
logits_flat, response_flat, reduction="none"
- ).view(batch_size, response_len)
+ ).view(local_batch_size, response_len)
# Apply padding mask and reduce
masked_loss = loss * padding_mask
@@ -69,7 +69,7 @@ def simple_grpo_loss(
def generate_random_batch(
- batch_size: int,
+ local_batch_size: int,
request_len: int,
response_len: int,
vocab_size: int = 32000,
@@ -86,19 +86,28 @@ def generate_random_batch(
# Create one batch for each data parallel rank
for _ in range(dp_size):
request = torch.randint(
- 1, vocab_size, (batch_size, request_len), dtype=torch.long, device=device
+ 1,
+ vocab_size,
+ (local_batch_size, request_len),
+ dtype=torch.long,
+ device=device,
)
response = torch.randint(
- 1, vocab_size, (batch_size, response_len), dtype=torch.long, device=device
+ 1,
+ vocab_size,
+ (local_batch_size, response_len),
+ dtype=torch.long,
+ device=device,
)
# Create padding mask (randomly mask some tokens as padding)
- padding_mask = torch.rand((batch_size, response_len), device=device) > 0.1
+ padding_mask = torch.rand((local_batch_size, response_len), device=device) > 0.1
ref_logprobs = (
- -torch.abs(torch.randn((batch_size, response_len), device=device)) - 1.0
+ -torch.abs(torch.randn((local_batch_size, response_len), device=device))
+ - 1.0
)
- advantages = torch.randn((batch_size, 1), device=device)
+ advantages = torch.randn((local_batch_size, 1), device=device)
input_tokens = torch.cat([request, response], dim=1)
inputs.append({"tokens": input_tokens})
targets.append(
@@ -133,7 +142,9 @@ async def main(cfg: DictConfig):
"""
# Extract training parameters from existing GRPO config fields
- batch_size = cfg.get("batch_size", 4)
+ local_batch_size = cfg.get("local_batch_size", None)
+ assert local_batch_size is not None, "local_batch_size must be specified"
+
request_len = cfg.get("max_req_tokens", 128)
response_len = cfg.get("max_res_tokens", 128)
max_training_steps = cfg.trainer.training.get("steps", 100)
@@ -156,7 +167,7 @@ async def main(cfg: DictConfig):
await init_provisioner(
ProvisionerConfig(
launcher_config=LauncherConfig(
- launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
+ launcher=cfg.get(LAUNCHER_KEY, Launcher.SLURM.value),
job_name=cfg.get(JOB_NAME_KEY, None),
services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
@@ -171,11 +182,11 @@ async def main(cfg: DictConfig):
await ts.initialize(strategy=ts.ControllerStorageVolumes())
# Initialize trainer only
print("Initializing trainer...")
- trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor(
+ trainer = await TitanTrainer.options(**cfg.actors.trainer).as_actor(
**cfg.trainer, loss=simple_grpo_loss
)
print("Trainer initialized successfully with following configs!")
- print(f" - Batch size: {batch_size}")
+ print(f" - Local batch size: {local_batch_size}")
print(f" - Request length: {request_len}")
print(f" - Response length: {response_len}")
print(f" - Vocab size: {vocab_size}")
@@ -191,7 +202,7 @@ async def continuous_training():
t.start()
inputs, targets = generate_random_batch(
- batch_size=batch_size,
+ local_batch_size=local_batch_size,
request_len=request_len,
response_len=response_len,
vocab_size=vocab_size,
@@ -221,9 +232,7 @@ async def continuous_training():
except KeyboardInterrupt:
print("Training interrupted by user")
finally:
- print("Shutting down trainer...")
- await RLTrainer.shutdown(trainer)
- await mlogger.shutdown.call_one()
+ print("Shutting down...")
await shutdown()
print("Trainer shutdown complete.")
diff --git a/apps/toy_rl/__init__.py b/tests/sandbox/toy_rl/__init__.py
similarity index 100%
rename from apps/toy_rl/__init__.py
rename to tests/sandbox/toy_rl/__init__.py
diff --git a/apps/toy_rl/sumdigits-tp.yaml b/tests/sandbox/toy_rl/sumdigits-tp.yaml
similarity index 91%
rename from apps/toy_rl/sumdigits-tp.yaml
rename to tests/sandbox/toy_rl/sumdigits-tp.yaml
index 87f58d5ea..74fb57e4a 100644
--- a/apps/toy_rl/sumdigits-tp.yaml
+++ b/tests/sandbox/toy_rl/sumdigits-tp.yaml
@@ -13,23 +13,21 @@ dataset:
# Policy configuration
policy:
- engine_config:
+ engine_args:
model: ${model}
tensor_parallel_size: 2
pipeline_parallel_size: 1
enforce_eager: false
- sampling_config:
+ sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
- use_vllm_builtin_load: true
# Trainer configuration
trainer:
model_name: ${model}
learning_rate: 1e-5
- use_vllm_builtin_load: true
# Reference model configuration
ref_model:
diff --git a/apps/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py
similarity index 86%
rename from apps/toy_rl/sumdigits.py
rename to tests/sandbox/toy_rl/sumdigits.py
index 57971e9b9..56b669ce4 100644
--- a/apps/toy_rl/sumdigits.py
+++ b/tests/sandbox/toy_rl/sumdigits.py
@@ -4,11 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-# Usage: python -m apps.toy_rl.sumdigits --config apps/toy_rl/sumdigits.yaml
+# Usage: python -m tests.sandbox.toy_rl.sumdigits --config tests/sandbox/toy_rl/sumdigits.yaml
import asyncio
import random
-import time
import uuid
from dataclasses import dataclass
from typing import Any
@@ -17,20 +16,19 @@
import torch.nn.functional as F
import torchstore as ts
from forge.actors._torchstore_utils import get_param_key
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
from forge.actors.replay_buffer import ReplayBuffer
-from forge.actors.trainer import _qwen3_hf_to_vllm
-from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
-
from forge.losses.grpo_loss import SimpleGRPOLoss
-from forge.util.metric_logging import get_metric_logger
-from forge.util.ops import selective_log_softmax
+from forge.observability.metric_actors import get_or_create_metric_logger
+
+from forge.observability.metrics import record_metric, Reduce
+from forge.util.config import parse
+from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
from omegaconf import DictConfig
-from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -220,7 +218,6 @@ def __init__(self, model_name, device: torch.device | None = None):
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
- dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)
self.model.eval()
@@ -244,7 +241,8 @@ async def forward(self, episode: Episode) -> torch.Tensor:
with torch.inference_mode():
logits = self.model(input_ids=input_ids, attention_mask=mask).logits
- return selective_log_softmax(logits, target_ids).squeeze(0)
+ log_probs = compute_logprobs(logits, target_ids, align=False)
+ return log_probs.squeeze(0)
@dataclass
@@ -255,7 +253,6 @@ class Trainer(ForgeActor):
learning_rate: float = 1e-5
device: torch.device | None = None
state_dict_key: str = "model_state_dict"
- use_vllm_builtin_load: bool = True
def __post_init__(self):
super().__init__()
@@ -267,7 +264,6 @@ async def setup(self):
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
- dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)
self.model.train()
@@ -330,7 +326,7 @@ def train_step(self, episodes: list[Episode]) -> float:
# Forward pass
logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
- trainer_log_probs = selective_log_softmax(logits, target_ids)
+ trainer_log_probs = compute_logprobs(logits, target_ids, align=False)
# Compute loss only on response tokens
# loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
loss = self.loss(trainer_log_probs, ref_logprobs, weights, loss_masks)
@@ -342,38 +338,9 @@ def train_step(self, episodes: list[Episode]) -> float:
self.optimizer.zero_grad(set_to_none=True)
return loss.item()
- @endpoint
- async def push_weights_DEPRECATED( # noqa: N802
- self, policy_version: int, vllm_tp_DEPRECATED: int = 1
- ):
- """Update policy model weights with trainer's current weights.
- This method pushes weights to torchstore in the vllm format,
- which is buggy and not scalable to other models. Deprecated.
- """
- return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)
-
- async def _push_weights_DEPRECATED( # noqa: N802
- self, version: int, vllm_tp_DEPRECATED: int
- ) -> None:
- """Update policy model weights with trainer's current weights."""
- key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
- new_sd = _qwen3_hf_to_vllm(
- self.model.state_dict(),
- num_layers=self.model.config.num_hidden_layers,
- vllm_tp=vllm_tp_DEPRECATED,
- )
- start_time = time.time()
- await ts.put_state_dict(new_sd, key)
- end_time = time.time()
- self.logger.debug(
- f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
- )
-
@endpoint
async def push_weights(self, policy_version: int) -> None:
"""Push weights to torchstore in HF format."""
- if not self.use_vllm_builtin_load:
- return await self._push_weights_DEPRECATED(policy_version)
hf_state_dict = self.model.state_dict()
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
@@ -461,17 +428,14 @@ async def main(cfg: DictConfig):
group_size = cfg.group_size
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens
- # TODO: delete this logic after we are confident on the vllm weight sync long term fix PR #184
- policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
- mlogger = get_metric_logger(
- "wandb",
- freq=1,
- project="sumdigits-training",
- )
# ---- Setup services ---- #
print(f"{cfg.policy=}")
print(f"{cfg.services.policy=}")
+
+ metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
+ mlogger = await get_or_create_metric_logger()
+ await mlogger.init_backends.call_one(metric_logging_cfg)
await ts.initialize()
(
dataloader,
@@ -482,7 +446,7 @@ async def main(cfg: DictConfig):
ref_model,
) = await asyncio.gather(
DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
- Policy.options(**cfg.services.policy).as_service(**cfg.policy),
+ Generator.options(**cfg.services.policy).as_service(**cfg.policy),
Trainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer),
ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(**cfg.replay_buffer),
RewardActor.options(**cfg.services.reward_actor).as_service(),
@@ -533,9 +497,9 @@ async def continuous_rollouts():
avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
- mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
+ record_metric("avg_response_len/rollout", avg_response_len, Reduce.MEAN)
avg_reward = sum(e.reward for e in group.episodes) / group_size
- mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
+ record_metric("avg_reward/rollout", avg_reward, Reduce.MEAN)
rollout_count += 1
@@ -550,7 +514,7 @@ async def continuous_training():
else:
loss = await trainer.train_step.call_one(batch[0])
training_step += 1
- mlogger.log("loss/training_step", loss, training_step)
+ record_metric("loss/training_step", loss, Reduce.MEAN)
print(f"loss/training_step: {loss} at training step {training_step}")
await trainer.push_weights.call(training_step)
await policy.update_weights.fanout(training_step)
@@ -570,15 +534,6 @@ async def continuous_training():
training_task.cancel()
finally:
print("Shutting down...")
- await asyncio.gather(
- DatasetActor.shutdown(dataloader),
- policy.shutdown(),
- Trainer.shutdown(trainer),
- ReplayBuffer.shutdown(replay_buffer),
- reward_actor.shutdown(),
- )
- # TODO - add a global shutdown that implicitly shuts down all services
- # and remote allocations
await shutdown()
diff --git a/apps/toy_rl/sumdigits.yaml b/tests/sandbox/toy_rl/sumdigits.yaml
similarity index 96%
rename from apps/toy_rl/sumdigits.yaml
rename to tests/sandbox/toy_rl/sumdigits.yaml
index 767bf7f3b..06a192431 100644
--- a/apps/toy_rl/sumdigits.yaml
+++ b/tests/sandbox/toy_rl/sumdigits.yaml
@@ -14,12 +14,12 @@ dataset:
# Policy configuration
policy:
use_dcp: false
- engine_config:
+ engine_args:
model: ${model}
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: false
- sampling_config:
+ sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
diff --git a/apps/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py
similarity index 82%
rename from apps/toy_rl/toy_metrics/main.py
rename to tests/sandbox/toy_rl/toy_metrics/main.py
index d999fb700..f72bada48 100644
--- a/apps/toy_rl/toy_metrics/main.py
+++ b/tests/sandbox/toy_rl/toy_metrics/main.py
@@ -59,7 +59,6 @@ async def generate_step(self, step: int, substep: int):
rank = current_rank().rank
with trace("policy_perf", track_memory=False, timer="gpu") as tracer:
-
value = rank * 1000 + step * 100 + substep * 10
tracer.step("time_to_value")
# Record generation metrics following the plan
@@ -82,25 +81,27 @@ async def main():
group = f"grpo_exp_{int(time.time())}"
# Config format: {backend_name: backend_config_dict}
- # Each backend can specify reduce_across_ranks to control distributed logging behavior
config = {
- "console": {"reduce_across_ranks": True},
+ "console": {"logging_mode": "global_reduce"},
"wandb": {
- "project": "my_project",
+ "project": "toy_metrics",
"group": group,
- "reduce_across_ranks": False,
- # Only useful if NOT reduce_across_ranks.
- "share_run_id": False, # Share run ID across ranks -- Not recommended.
+ "logging_mode": "per_rank_reduce", # global_reduce, per_rank_reduce, per_rank_no_reduce
+ "per_rank_share_run": True,
},
}
service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
- mlogger = await get_or_create_metric_logger()
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(config)
# Spawn services first (triggers registrations via provisioner hook)
- trainer = await TrainActor.options(**service_config).as_service()
- generator = await GeneratorActor.options(**service_config).as_service()
+ trainer = await TrainActor.options(
+ **service_config, mesh_name="TrainActor"
+ ).as_service()
+ generator = await GeneratorActor.options(
+ **service_config, mesh_name="GeneratorActor"
+ ).as_service()
for i in range(3):
print(f"\n=== Global Step {i} ===")
@@ -110,14 +111,6 @@ async def main():
await mlogger.flush.call_one(i)
# shutdown
- await mlogger.shutdown.call_one()
- await asyncio.sleep(2)
-
- await asyncio.gather(
- trainer.shutdown(),
- generator.shutdown(),
- )
-
await shutdown()
diff --git a/apps/vllm/deepseek_r1.yaml b/tests/sandbox/vllm/deepseek_r1.yaml
similarity index 73%
rename from apps/vllm/deepseek_r1.yaml
rename to tests/sandbox/vllm/deepseek_r1.yaml
index 7a0c2ad2d..fd4228d5a 100644
--- a/apps/vllm/deepseek_r1.yaml
+++ b/tests/sandbox/vllm/deepseek_r1.yaml
@@ -1,18 +1,20 @@
-# >>> python -m apps.vllm.main --config apps/vllm/deepseek_r1.yaml
+# >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/deepseek_r1.yaml
# NOTE - this won't work until we have proper HostMesh support
policy:
- engine_config:
+ engine_args:
model: "deepseek-ai/DeepSeek-R1-0528"
tensor_parallel_size: 16
pipeline_parallel_size: 1
enable_expert_parallel: true
# enforce_eager: true
- sampling_config:
+ sampling_params:
n: 2
- guided_decoding: false
max_tokens: 512
+provisioner:
+ launcher: slurm
+
services:
policy:
procs: 8
diff --git a/apps/vllm/llama3_8b.yaml b/tests/sandbox/vllm/llama3_8b.yaml
similarity index 62%
rename from apps/vllm/llama3_8b.yaml
rename to tests/sandbox/vllm/llama3_8b.yaml
index c4bc141bf..95a2ad53a 100644
--- a/apps/vllm/llama3_8b.yaml
+++ b/tests/sandbox/vllm/llama3_8b.yaml
@@ -1,19 +1,18 @@
-# >>> python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
+# >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml
policy:
- engine_config:
+ engine_args:
model: "meta-llama/Llama-3.1-8B-Instruct"
tensor_parallel_size: 2
pipeline_parallel_size: 1
enforce_eager: true
- sampling_config:
+ sampling_params:
n: 2
- guided_decoding: false
max_tokens: 512
services:
policy:
- procs: ${policy.engine_config.tensor_parallel_size}
+ procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 4
with_gpus: true
diff --git a/apps/vllm/main.py b/tests/sandbox/vllm/main.py
similarity index 68%
rename from apps/vllm/main.py
rename to tests/sandbox/vllm/main.py
index 3167817c7..f41dec56a 100644
--- a/apps/vllm/main.py
+++ b/tests/sandbox/vllm/main.py
@@ -6,19 +6,21 @@
"""To run:
export HF_HUB_DISABLE_XET=1
-python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
+python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml
"""
import asyncio
import os
-from forge.actors.policy import Policy
-from forge.cli.config import parse
-from forge.controller.provisioner import shutdown
+from forge.actors.generator import Generator
+
+from forge.controller.provisioner import init_provisioner, shutdown
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.config import parse
from omegaconf import DictConfig
os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
@@ -26,16 +28,21 @@
async def run(cfg: DictConfig):
- metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
- mlogger = await get_or_create_metric_logger()
+ if cfg.get("provisioner", None) is not None:
+ await init_provisioner(
+ ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+ )
+ metric_logging_cfg = cfg.get(
+ "metric_logging", {"console": {"logging_mode": "global_reduce"}}
+ )
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)
if (prompt := cfg.get("prompt")) is None:
- gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
- prompt = "What is 3+5?" if gd else "Tell me a joke"
+ prompt = "Tell me a joke"
print("Spawning service...")
- policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
+ policy = await Generator.options(**cfg.services.policy).as_service(**cfg.policy)
import time
@@ -61,7 +68,6 @@ async def run(cfg: DictConfig):
print("-" * 80)
print("\nShutting down...")
- await policy.shutdown()
await shutdown()
diff --git a/apps/vllm/qwen2_5_32b.yaml b/tests/sandbox/vllm/qwen2_5_32b.yaml
similarity index 69%
rename from apps/vllm/qwen2_5_32b.yaml
rename to tests/sandbox/vllm/qwen2_5_32b.yaml
index 72d55781b..6590b791a 100644
--- a/apps/vllm/qwen2_5_32b.yaml
+++ b/tests/sandbox/vllm/qwen2_5_32b.yaml
@@ -1,14 +1,13 @@
-# >>> python -m apps.vllm.main --config apps/vllm/qwen2_5_32b.yaml
+# >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/qwen2_5_32b.yaml
policy:
- engine_config:
+ engine_args:
model: "Qwen/Qwen2.5-32B"
tensor_parallel_size: 4
pipeline_parallel_size: 1
enforce_eager: true
- sampling_config:
+ sampling_params:
n: 2
- guided_decoding: false
max_tokens: 512
services:
diff --git a/tests/sandbox/weight_sync/main.py b/tests/sandbox/weight_sync/main.py
new file mode 100644
index 000000000..dfdc58f0a
--- /dev/null
+++ b/tests/sandbox/weight_sync/main.py
@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Weight Sync Sandbox
+
+A minimal test environment focused exclusively on testing the weight synchronization
+mechanism between RLTrainer and Generator.
+
+Usage:
+ python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml
+"""
+
+import asyncio
+import time
+
+import torch
+import torchstore as ts
+from forge.actors._torchstore_utils import rdma_enabled
+from forge.actors.generator import Generator
+from forge.actors.trainer import RLTrainer
+from forge.controller.provisioner import init_provisioner, shutdown
+from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.config import parse
+from omegaconf import DictConfig
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+def generate_random_batch(
+ local_batch_size: int,
+ request_len: int,
+ response_len: int,
+ vocab_size: int = 32000,
+ device: str = "cuda",
+ dp_size: int = 1,
+):
+ """
+ Generate random input and target tensors for a single training step.
+ Creates one batch per data parallel rank.
+ """
+ inputs = []
+ targets = []
+
+ # Create one batch for each data parallel rank
+ for _ in range(dp_size):
+ request = torch.randint(
+ 1,
+ vocab_size,
+ (local_batch_size, request_len),
+ dtype=torch.long,
+ device=device,
+ )
+ response = torch.randint(
+ 1,
+ vocab_size,
+ (local_batch_size, response_len),
+ dtype=torch.long,
+ device=device,
+ )
+
+ # Create padding mask
+ padding_mask = torch.rand((local_batch_size, response_len), device=device) > 0.1
+
+ ref_logprobs = (
+ -torch.abs(torch.randn((local_batch_size, response_len), device=device))
+ - 1.0
+ )
+ advantages = torch.randn((local_batch_size, 1), device=device)
+ input_tokens = torch.cat([request, response], dim=1)
+ inputs.append({"tokens": input_tokens})
+ targets.append(
+ {
+ "response": response,
+ "ref_logprobs": ref_logprobs,
+ "advantages": advantages,
+ "padding_mask": padding_mask,
+ }
+ )
+
+ return inputs, targets
+
+
+async def main(cfg: DictConfig):
+ local_batch_size = cfg.get("local_batch_size", None)
+ assert local_batch_size is not None, "local_batch_size must be specified"
+
+ request_len = cfg.get("max_req_tokens", 64)
+ response_len = cfg.get("max_res_tokens", 64)
+ model_name = cfg.get("model")
+
+ print(f"Loading tokenizer for model: {model_name}")
+ tokenizer = get_tokenizer(model_name)
+ vocab_size = tokenizer.vocab_size
+ print(f"Detected vocab size: {vocab_size}")
+
+ trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
+ dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1
+
+ # ---- Global setups ---- #
+ provisioner = None
+ if cfg.get("provisioner", None) is not None:
+ provisioner = await init_provisioner(
+ ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+ )
+ else:
+ provisioner = await init_provisioner()
+
+ metric_logging_cfg = cfg.get("metric_logging", {})
+ mlogger = await get_or_create_metric_logger(process_name="Controller")
+ await mlogger.init_backends.call_one(metric_logging_cfg)
+
+ # Initialize torchstore
+ await ts.initialize(strategy=ts.ControllerStorageVolumes())
+
+ print("=" * 80)
+ print(f"Model: {model_name}")
+ print(f"Local batch size: {local_batch_size}")
+ print(
+ f"Sequence length: {request_len + response_len} ({request_len} + {response_len})"
+ )
+ print(f"Data parallel size: {dp_size}")
+ print(f"Is RDMA available? {rdma_enabled()}")
+ print("=" * 80 + "\n")
+
+ # Initialize trainer and generator
+ print("Initializing trainer and generator...")
+ init_start = time.time()
+
+ trainer, policy = await asyncio.gather(
+ RLTrainer.options(**cfg.actors.trainer).as_actor(
+ **cfg.trainer,
+ loss=lambda *args, **kwargs: torch.tensor(
+ 1.0, requires_grad=True, device="cuda"
+ ),
+ ),
+ Generator.options(**cfg.actors.policy).as_actor(**cfg.policy),
+ )
+
+ init_time = time.time() - init_start
+ print(f"Finished initialization in ({init_time:.2f}s)")
+
+ # Run one training step to create weight delta
+ print("Running single training step...")
+ step_start = time.time()
+
+ inputs, targets = generate_random_batch(
+ local_batch_size=local_batch_size,
+ request_len=request_len,
+ response_len=response_len,
+ vocab_size=vocab_size,
+ dp_size=dp_size,
+ )
+
+ await trainer.train_step.call(inputs, targets)
+ step_time = time.time() - step_start
+ print(f"Finished train step in ({step_time:.2f}s)\n")
+
+ # Test push_weights
+ print("Pushing weights from trainer to store...")
+ push_start = time.time()
+
+ await trainer.push_weights.call(policy_version=1)
+
+ push_time = time.time() - push_start
+ print(f"Finished weights push in ({push_time:.2f}s)\n")
+
+ # Test update_weights
+ print("Updating generator weights from store...")
+ update_start = time.time()
+
+ await policy.update_weights.call(version=1)
+
+ update_time = time.time() - update_start
+ print(f"Updated generator weights ({update_time:.2f}s)\n")
+
+ # TODO - ideally we have the capability to check forward passes between
+ # the trainer/generator to verify correctness. This would require adding
+ # forward capabilities to both trainer/generator actors.
+
+ # Summary
+ print("=" * 80)
+ print("Results")
+ print("=" * 80)
+ print(f"Push time: {push_time:.2f}s")
+ print(f"Update time: {update_time:.2f}s")
+ print(f"Total sync time: {push_time + update_time:.2f}s")
+ print("=" * 80 + "\n")
+
+ # Cleanup
+ print("Shutting down...")
+ await shutdown()
+ print("Shutdown complete.")
+
+
+if __name__ == "__main__":
+
+ @parse
+ def _main(cfg):
+ asyncio.run(main(cfg))
+
+ _main()
diff --git a/tests/sandbox/weight_sync/qwen3_1_7b.yaml b/tests/sandbox/weight_sync/qwen3_1_7b.yaml
new file mode 100644
index 000000000..e18589eaa
--- /dev/null
+++ b/tests/sandbox/weight_sync/qwen3_1_7b.yaml
@@ -0,0 +1,75 @@
+# Weight Sync Sandbox Configuration
+# >>> python -m tests.sandbox.weight_sync.main --config tests/sandbox/weight_sync/qwen3_1_7b.yaml
+
+model: "Qwen/Qwen3-1.7B"
+local_batch_size: 4
+max_req_tokens: 64
+max_res_tokens: 64
+compile: true # Enable torch.compile for trainer, and CUDA graphs for vLLM
+
+metric_logging:
+ console:
+ logging_mode: global_reduce
+
+policy:
+ prefetch_weights_to_shm: false # Disable to avoid shared memory warnings in test
+ engine_args:
+ model: ${model}
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+ enforce_eager: ${not:${compile}}
+ sampling_params:
+ n: 1
+ max_tokens: 32 # Just for verification forward pass
+ temperature: 1.0
+ top_p: 1.0
+
+trainer:
+ model:
+ name: qwen3
+ flavor: 1.7B
+ hf_assets_path: hf://${model}
+ optimizer:
+ name: AdamW
+ lr: 1e-5
+ eps: 1e-8
+ lr_scheduler:
+ warmup_steps: 1
+ training:
+ local_batch_size: ${local_batch_size}
+ seq_len: 128 # max_req_tokens + max_res_tokens
+ max_norm: 1.0
+ steps: 1 # We only run 1 step
+ dtype: bfloat16
+ gc_freq: 1
+ compile:
+ enable: ${compile}
+ parallelism:
+ data_parallel_replicate_degree: 1
+ data_parallel_shard_degree: 1 # Single GPU, no FSDP
+ tensor_parallel_degree: 1
+ pipeline_parallel_degree: 1
+ context_parallel_degree: 1
+ expert_parallel_degree: 1
+ disable_loss_parallel: true
+ checkpoint:
+ enable: true
+ folder: ./checkpoint
+ initial_load_path: hf://${model}
+ initial_load_in_hf: true
+ last_save_in_hf: true
+ async_mode: "disabled"
+ activation_checkpoint:
+ mode: selective
+ selective_ac_option: op
+
+# Resource allocation - both as actors
+actors:
+ policy:
+ procs: 1 # Single process for generator
+ with_gpus: true
+ mesh_name: policy
+ trainer:
+ procs: 1 # Single process for trainer
+ with_gpus: true
+ mesh_name: trainer
diff --git a/tests/unit_tests/data/test_metrics_aggregator.py b/tests/unit_tests/data/test_metrics_aggregator.py
deleted file mode 100644
index 5b847c92f..000000000
--- a/tests/unit_tests/data/test_metrics_aggregator.py
+++ /dev/null
@@ -1,456 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Tests for MetricsAggregator functionality.
-
-This module tests the metrics collection and aggregation system including:
-- All aggregation types (SUM, MEAN, MAX, MIN, STATS, CATEGORICAL_COUNT)
-- State management and checkpointing
-- Multi-dataset metric namespacing
-- Distributed metrics aggregation
-- Metric consistency validation
-
-Uses synthetic metrics to verify correct aggregation behavior across scenarios.
-"""
-
-import logging
-
-import pytest
-import torch.distributed as dist
-
-from forge.data.dataset_metrics import AggregationType, Metric, MetricsAggregator
-from torch.testing._internal.common_fsdp import FSDPTest
-
-from tests.test_utils import gpu_test
-
-
-class TestMetricsAggregator:
- """Tests for MetricsAggregator core functionality and edge cases."""
-
- @pytest.mark.parametrize(
- "agg_type,test_values,expected",
- [
- (AggregationType.SUM, [1, 2, 3, 4], 10),
- (AggregationType.MEAN, [10, 20, 30, 40], 25.0),
- (AggregationType.MAX, [-5, 10, 3, 15], 15),
- (AggregationType.MIN, [5, -2, 8, 1], -2),
- (
- AggregationType.CATEGORICAL_COUNT,
- ["A", "B", "A", "C", "A"],
- {"A": 3, "B": 1, "C": 1},
- ),
- ],
- )
- def test_aggregation_types(self, agg_type, test_values, expected):
- """Tests each AggregationType with representative data to verify correct computation.
-
- Covers aggregation types:
- - SUM: Simple addition across values
- - MEAN: Average computation with proper count tracking
- - MAX/MIN: Extrema identification
- - CATEGORICAL_COUNT: Category frequency counting
- """
- aggregator = MetricsAggregator()
-
- metrics = [
- Metric(source="test", metric_name="metric", value=val, agg_type=agg_type)
- for val in test_values
- ]
- aggregator.update(metrics)
-
- result = aggregator.get_metrics_for_logging(prefix="train")
-
- if agg_type == AggregationType.CATEGORICAL_COUNT:
- for category, count in expected.items():
- assert result[f"train_test/metric_count_{category}"] == count
- else:
- assert result["train_test/metric"] == expected
-
- def test_stats_metrics(self):
- """Tests that STATS aggregation computes statistics (mean, min, max, percentiles)."""
- aggregator = MetricsAggregator()
- values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-
- metrics = [
- Metric("test", "dist_metric", val, AggregationType.STATS) for val in values
- ]
- aggregator.update(metrics)
-
- result = aggregator.get_metrics_for_logging(prefix="train")
-
- assert result["train_test/dist_metric_stat_mean"] == 5.5
- assert result["train_test/dist_metric_stat_min"] == 1
- assert result["train_test/dist_metric_stat_max"] == 10
- assert result["train_test/dist_metric_stat_p50"] == 5.5
-
- def test_state_management(self):
- """Test metrics aggregator state persistence and restoration for checkpointing scenarios."""
- # Create aggregator with mixed metric types to test state saving
- aggregator1 = MetricsAggregator()
- initial_metrics = [
- Metric("ds1", "counter", 10, AggregationType.SUM),
- Metric("ds1", "average", 5.0, AggregationType.MEAN),
- Metric("ds2", "categories", "X", AggregationType.CATEGORICAL_COUNT),
- ]
- aggregator1.update(initial_metrics)
-
- # Save state
- state = aggregator1.state_dict()
-
- # Create new aggregator and restore state
- aggregator2 = MetricsAggregator()
- aggregator2.load_state_dict(state)
-
- # Both should have identical metrics
- metrics1 = aggregator1.get_metrics_for_logging(prefix="train")
- metrics2 = aggregator2.get_metrics_for_logging(prefix="train")
- assert metrics1 == metrics2
-
- # Continue updating both - should remain identical
- additional_metrics = [
- Metric("ds1", "counter", 5, AggregationType.SUM),
- Metric("ds1", "average", 15.0, AggregationType.MEAN),
- ]
- aggregator1.update(additional_metrics)
- aggregator2.update(additional_metrics)
-
- final_metrics1 = aggregator1.get_metrics_for_logging(prefix="train")
- final_metrics2 = aggregator2.get_metrics_for_logging(prefix="train")
- assert final_metrics1 == final_metrics2
-
- # Verify expected values
- assert final_metrics1["train_ds1/counter"] == 15 # 10 + 5
- assert final_metrics1["train_ds1/average"] == 10.0 # (5 + 15) / 2
-
- def test_multiple_datasets(self):
- """Test that metrics from multiple datasets are correctly namespaced."""
- aggregator = MetricsAggregator()
-
- metrics = [
- Metric("dataset1", "samples", 100, AggregationType.SUM),
- Metric("dataset2", "samples", 200, AggregationType.SUM),
- Metric("dataset1", "tokens", 1000, AggregationType.SUM),
- Metric("dataset2", "tokens", 2000, AggregationType.SUM),
- ]
- aggregator.update(metrics)
-
- result = aggregator.get_metrics_for_logging(prefix="train")
-
- assert result["train_dataset1/samples"] == 100
- assert result["train_dataset2/samples"] == 200
- assert result["train_dataset1/tokens"] == 1000
- assert result["train_dataset2/tokens"] == 2000
-
- def test_empty_aggregator(self):
- """Test that empty aggregator returns empty metrics."""
- aggregator = MetricsAggregator()
- result = aggregator.get_metrics_for_logging(prefix="train")
- assert result == {}
-
- def test_prefix_handling(self):
- """Test that prefix is correctly applied to metric keys."""
- aggregator = MetricsAggregator()
- metrics = [
- Metric("test_ds", "metric1", 42, AggregationType.SUM),
- Metric("test_ds", "metric2", 84, AggregationType.SUM),
- ]
- aggregator.update(metrics)
-
- # Test with prefix
- result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation")
- assert result_with_prefix["validation_test_ds/metric1"] == 42
- assert result_with_prefix["validation_test_ds/metric2"] == 84
-
- # Test without prefix (uses default "data")
- result_no_prefix = aggregator.get_metrics_for_logging()
- assert result_no_prefix["data_test_ds/metric1"] == 42
- assert result_no_prefix["data_test_ds/metric2"] == 84
-
- def test_metric_consistency_validation(self):
- """Test that same metric name must use same aggregation type."""
- aggregator = MetricsAggregator()
-
- # First metric with SUM aggregation
- metrics1 = [Metric("test", "my_metric", 10, AggregationType.SUM)]
- aggregator.update(metrics1)
-
- # Try to use same metric name with different aggregation type - should fail
- metrics2 = [Metric("test", "my_metric", 5.0, AggregationType.MEAN)]
- with pytest.raises(
- ValueError, match="is already registered with aggregation type sum"
- ):
- aggregator.update(metrics2)
-
- # Same metric name with same aggregation type should work
- metrics3 = [Metric("test", "my_metric", 20, AggregationType.SUM)]
- aggregator.update(metrics3) # Should not raise
-
- result = aggregator.get_metrics_for_logging(prefix="train")
- assert result["train_test/my_metric"] == 30 # 10 + 20
-
- def test_metric_consistency_across_datasets(self):
- """Test that same metric name can use different aggregation types across different datasets."""
- aggregator = MetricsAggregator()
-
- # Same metric name but different datasets - should be allowed
- metrics = [
- Metric("dataset1", "metric", 10, AggregationType.SUM),
- Metric("dataset2", "metric", 5.0, AggregationType.MEAN),
- ]
- aggregator.update(metrics) # Should not raise
-
- result = aggregator.get_metrics_for_logging(prefix="train")
- assert result["train_dataset1/metric"] == 10
- assert result["train_dataset2/metric"] == 5.0
-
- def test_handler_generated_metric_validation(self):
- """Test that handler-generated metrics are validated for consistency."""
- aggregator = MetricsAggregator()
-
- # Create a user-defined metric that will conflict with stats
- user_metrics = [
- Metric("test", "dist_metric_stat_mean", 42, AggregationType.SUM)
- ]
- aggregator.update(user_metrics)
-
- # Now try to add a stats metric that will generate conflicting stat names
- dist_metrics = [Metric("test", "dist_metric", 10, AggregationType.STATS)]
- aggregator.update(dist_metrics)
-
- # This should fail when trying to get metrics for logging because the handler
- # will try to create "dist_metric_stat_mean" which conflicts with the user metric
- with pytest.raises(
- ValueError, match="is already registered with aggregation type sum"
- ):
- aggregator.get_metrics_for_logging(prefix="train")
-
- def test_handler_replacement_warning(self, caplog):
- """Test that replacing handlers in use generates a warning."""
- aggregator = MetricsAggregator()
-
- # Add a metric that uses SUM aggregation
- metrics = [Metric("test", "sum_metric", 10, AggregationType.SUM)]
- aggregator.update(metrics)
-
- # Replace the SUM handler - should generate warning
- from forge.data.dataset_metrics import SumAggHandler
-
- with caplog.at_level(logging.WARNING):
- aggregator.register_handler(AggregationType.SUM, SumAggHandler())
-
- # Check that the expected warning was logged
- assert len(caplog.records) == 1
- assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message
-
-
-class TestDistributedMetricsAggregator(FSDPTest):
- """Distributed tests for MetricsAggregator using FSDPTest infrastructure."""
-
- @property
- def world_size(self) -> int:
- return 2
-
- @gpu_test(gpu_count=2)
- def test_distributed_all_aggregation_types(self):
- """
- Test that all aggregation types work correctly in distributed setting.
- Each rank contributes different values to ensure proper reduction across ranks.
- """
- aggregator = MetricsAggregator()
- rank = dist.get_rank()
-
- # Each rank contributes different values to test cross-rank aggregation
- base_value = (rank + 1) * 10 # rank 0: 10, rank 1: 20
-
- metrics = [
- Metric("test", "sum_metric", base_value, AggregationType.SUM),
- Metric("test", "mean_metric", base_value + 5, AggregationType.MEAN),
- Metric("test", "max_metric", base_value * 10, AggregationType.MAX),
- Metric("test", "min_metric", base_value // 2, AggregationType.MIN),
- ]
-
- # STATS: Each rank adds 5 values for statistics
- # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14]
- for i in range(5):
- metrics.append(
- Metric("test", "dist_metric", rank * 10 + i, AggregationType.STATS)
- )
-
- # CATEGORICAL_COUNT: Different categories per rank to test counting
- # rank 0: 3 of cat_A, 2 of cat_B
- # rank 1: 1 of cat_A, 4 of cat_C
- if rank == 0:
- metrics.extend(
- [
- Metric(
- "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT
- ),
- ]
- )
- else:
- metrics.extend(
- [
- Metric(
- "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT
- ),
- Metric(
- "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT
- ),
- ]
- )
-
- # Update aggregator and get results
- aggregator.update(metrics)
- result = aggregator.get_metrics_for_logging(prefix="train")
-
- # Verify aggregation results across all ranks
- # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30
- # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20
- # MAX: rank 0 has 100, rank 1 has 200 -> max 200
- # MIN: rank 0 has 5, rank 1 has 10 -> min 5
- assert result["train_test/sum_metric"] == 30
- assert result["train_test/mean_metric"] == 20
- assert result["train_test/max_metric"] == 200
- assert result["train_test/min_metric"] == 5
-
- # STATS: Combined values [0,1,2,3,4,10,11,12,13,14]
- # Mean should be average of local means: (2 + 12) / 2 = 7
- assert result["train_test/dist_metric_stat_mean"] == 7
- assert result["train_test/dist_metric_stat_min"] == 0
- assert result["train_test/dist_metric_stat_max"] == 14
-
- # CATEGORICAL_COUNT: Total counts across ranks
- # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4
- assert result["train_test/cat_metric_count_cat_A"] == 4
- assert result["train_test/cat_metric_count_cat_B"] == 2
- assert result["train_test/cat_metric_count_cat_C"] == 4
-
- @gpu_test(gpu_count=2)
- def test_distributed_state_dict_resumption(self):
- """
- Test that MetricsAggregator state_dict save/restore works correctly in distributed setting.
- Verifies:
- - State can be saved after partial updates across ranks
- - State can be restored consistently across ranks
- - Continued updates after restore produce identical results
- - Distributed aggregation works correctly after restoration
- """
- rank = dist.get_rank()
-
- # Phase 1: Create aggregator and add initial metrics
- aggregator1 = MetricsAggregator()
-
- # Each rank contributes different initial values
- base_value = rank * 100 # rank 0: 0, rank 1: 100
-
- initial_metrics = [
- Metric("test", "sum_metric", base_value, AggregationType.SUM),
- Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN),
- Metric("test", "max_metric", base_value * 2, AggregationType.MAX),
- ]
-
- # Add some STATS values - each rank adds 3 values
- for i in range(3):
- initial_metrics.append(
- Metric("test", "dist_metric", rank * 100 + i, AggregationType.STATS)
- )
-
- # Add CATEGORICAL_COUNT values
- if rank == 0:
- initial_metrics.extend(
- [
- Metric(
- "test",
- "cat_metric",
- "type_A",
- AggregationType.CATEGORICAL_COUNT,
- ),
- Metric(
- "test",
- "cat_metric",
- "type_A",
- AggregationType.CATEGORICAL_COUNT,
- ),
- ]
- )
- else:
- initial_metrics.extend(
- [
- Metric(
- "test",
- "cat_metric",
- "type_B",
- AggregationType.CATEGORICAL_COUNT,
- ),
- Metric(
- "test",
- "cat_metric",
- "type_B",
- AggregationType.CATEGORICAL_COUNT,
- ),
- Metric(
- "test",
- "cat_metric",
- "type_B",
- AggregationType.CATEGORICAL_COUNT,
- ),
- ]
- )
-
- aggregator1.update(initial_metrics)
-
- # Save state_dict after initial update
- state_dict = aggregator1.state_dict()
-
- # Phase 2: Create new aggregator and restore from state_dict
- aggregator2 = MetricsAggregator()
- aggregator2.load_state_dict(state_dict)
-
- # Verify both aggregators produce identical results after restore
- result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint")
- result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint")
- assert (
- result1 == result2
- ), f"Rank {rank}: Aggregators differ after state_dict restore"
-
- # Phase 3: Add more metrics to both aggregators
- additional_metrics = [
- Metric("test", "sum_metric", rank * 1000, AggregationType.SUM),
- Metric("test", "min_metric", rank * 1000, AggregationType.MIN),
- ]
-
- # Update both aggregators with additional metrics
- aggregator1.update(additional_metrics)
- aggregator2.update(additional_metrics)
-
- # Phase 4: Verify final results are identical across both aggregators
- final_result1 = aggregator1.get_metrics_for_logging(prefix="final")
- final_result2 = aggregator2.get_metrics_for_logging(prefix="final")
- assert (
- final_result1 == final_result2
- ), f"Rank {rank}: Final results differ after continued updates"
diff --git a/tests/unit_tests/data/test_metrics_transform.py b/tests/unit_tests/data/test_metrics_transform.py
deleted file mode 100644
index 078b511a8..000000000
--- a/tests/unit_tests/data/test_metrics_transform.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Tests cover:
-- DefaultTrainingMetricTransform
-- Basic metric generation (samples_seen, tokens_seen, seq_len)
-- Dataset name validation and requirements
-- Proper metric type assignment and aggregation configuration
-"""
-
-import pytest
-
-from forge.data.dataset_metrics import AggregationType, DefaultTrainingMetricTransform
-
-
-class TestDefaultTrainingMetricTransform:
- """Tests for DefaultTrainingMetricTransform functionality."""
-
- def test_source_not_set_raises_error(self):
- """Test that the transform raises a RuntimeError if used before
- `set_source` is called, ensuring that metrics are always
- correctly attributed to a dataset."""
- transform = DefaultTrainingMetricTransform()
- sample = {"tokens": [1, 2, 3]}
-
- with pytest.raises(RuntimeError, match="set_source"):
- transform(sample)
-
- def test_basic_metrics_generation(self):
- """Test that transform generates expected training metrics for input samples."""
- transform = DefaultTrainingMetricTransform()
- # Set dataset name required for metric generation
- transform.set_source("test_dataset")
-
- sample = {"tokens": [1, 2, 3, 4, 5]}
- result = transform(sample)
-
- # Transform should preserve original sample data unchanged
- assert result["tokens"] == [1, 2, 3, 4, 5]
-
- # Should generate exactly 3 metrics: samples_seen, tokens_seen, seq_len
- assert "metrics" in result
- metrics = result["metrics"]
- assert len(metrics) == 3
-
- # Verify each metric has correct properties and aggregation type
- for metric in metrics:
- if metric.metric_name == "samples_seen":
- assert metric.source == "test_dataset"
- assert metric.value == 1
- assert metric.agg_type == AggregationType.SUM
-
- elif metric.metric_name == "tokens_seen":
- assert metric.source == "test_dataset"
- assert metric.value == 5
- assert metric.agg_type == AggregationType.SUM
-
- elif metric.metric_name == "seq_len":
- assert metric.source == "test_dataset"
- assert metric.value == 5
- assert metric.agg_type == AggregationType.STATS
diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py
index c1535c8b8..802619361 100644
--- a/tests/unit_tests/datasets/test_hf.py
+++ b/tests/unit_tests/datasets/test_hf.py
@@ -26,14 +26,14 @@
import pytest
import torch.distributed as dist
-from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator
from forge.data.datasets import HfIterableDataset
+from forge.data.metric_transform import DefaultDatasetMetricTransform
+
+from tests.test_utils import gpu_test
from torch.testing._internal.common_fsdp import FSDPTest
from torchdata.stateful_dataloader import StatefulDataLoader
-from tests.test_utils import gpu_test
-
from .test_iterable_utils import collate_with_metrics, generate_ckpt
# Test Constants - Avoid perfect divisions
@@ -93,7 +93,7 @@ def _create_dataset(
dataset_name=dataset_name,
seed=SEED,
shuffle_buffer_size=10 if shuffle else 0,
- metric_transform=DefaultTrainingMetricTransform(),
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=2,
**kwargs,
)
@@ -113,7 +113,7 @@ def test_default_dataset_name(self, small_dataset_file):
split="train",
# dataset_name not provided - should auto-generate
seed=SEED,
- metric_transform=DefaultTrainingMetricTransform(),
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=4,
)
@@ -131,7 +131,7 @@ def test_default_dataset_name(self, small_dataset_file):
dataset_name="my_dataset",
weight=custom_weight,
seed=SEED,
- metric_transform=DefaultTrainingMetricTransform(),
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=4,
)
@@ -149,17 +149,16 @@ def test_epoch_boundaries_and_checkpointing(
the epoch metric is correct, and checkpointing works as expected.
"""
- # 1. Setup Dataloaders and Aggregators for original and resumed runs
- def create_loader_and_aggregator():
+ # 1. Setup Dataloaders for original and resumed runs
+ def create_loader():
dataset = dataset_factory(small_dataset_file, shuffle=False)
loader = StatefulDataLoader(
dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics
)
- aggregator = MetricsAggregator()
- return loader, aggregator
+ return loader
- loader1, aggregator1 = create_loader_and_aggregator()
- loader2, aggregator2 = create_loader_and_aggregator()
+ loader1 = create_loader()
+ loader2 = create_loader()
# 2. Calculate steps for the test run
total_samples = int(SMALL_DATASET_SIZE * num_epochs)
@@ -171,11 +170,9 @@ def create_loader_and_aggregator():
# 3. Generate checkpoint and resume
result = generate_ckpt(
loader1,
- aggregator1,
steps_before_checkpoint=steps_before_checkpoint,
steps_after_checkpoint=steps_after_checkpoint,
resume_dataloader=loader2,
- resume_aggregator=aggregator2,
)
# 4. Verify checkpointing and resumption
@@ -184,9 +181,10 @@ def create_loader_and_aggregator():
assert (
orig_post_ids == resumed_ids
), "Resumed batches should be identical for deterministic run"
+
assert (
- result["final_metrics"] == result["resumed_metrics"]
- ), "Final metrics should match"
+ result["post_checkpoint_metrics"] == result["resumed_metrics"]
+ ), "Resumed training should produce same metrics as original training"
def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
"""Tests that shuffling changes data order between epochs but preserves the set of samples."""
@@ -233,10 +231,10 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
# But should contain the same set of IDs
assert set(first_epoch_ids) == set(
range(SMALL_DATASET_SIZE)
- ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}"
+ ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {first_epoch_ids}"
assert set(second_epoch_ids) == set(
range(SMALL_DATASET_SIZE)
- ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}"
+ ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {second_epoch_ids}"
def test_epoch_tracking(self, dataset_factory, small_dataset_file):
"""Test that epoch number is correctly tracked across dataset restarts."""
@@ -253,9 +251,7 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file):
for sample in first_epoch_samples:
first_epoch_metrics.extend(sample["metrics"])
epoch_values = [
- metric.value
- for metric in first_epoch_metrics
- if metric.metric_name == "epoch"
+ metric.value for metric in first_epoch_metrics if "num_epochs" in metric.key
]
assert all(
epoch_value == 0 for epoch_value in epoch_values
@@ -268,14 +264,82 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file):
epoch_values = [
metric.value
for metric in second_epoch_metrics
- if metric.metric_name == "epoch"
+ if "num_epochs" in metric.key
]
assert all(
epoch_value == 1 for epoch_value in epoch_values
), f"Epoch values should be 1, got {epoch_values}"
+ def test_multiple_iter_calls_after_resume(
+ self, dataset_factory, small_dataset_file
+ ):
+ """Test that calling iter() multiple times after resuming restarts from checkpoint epoch.
+
+ 1. Resume from checkpoint at epoch 2
+ 2. Consume one epoch (now at epoch 3)
+ 3. Call iter(ds) again to create a new iterator
+ 4. The new iterator should restart from epoch 2 (checkpoint epoch), not 0 or 3
+
+ This ensures datasets can be re-iterated from their checkpoint state.
+ """
+ dataset = dataset_factory(small_dataset_file, shuffle=False)
+
+ # consume 2 epochs
+ it1 = iter(dataset)
+ samples = list(islice(it1, SMALL_DATASET_SIZE * 2))
+
+ # Save checkpoint after 2 epochs
+ state = dataset.state_dict()
+
+ # Continue training for 1 more epoch on the same iterator
+ more_samples = list(islice(it1, SMALL_DATASET_SIZE))
+
+ # Create a new dataset instance and load the checkpoint
+ dataset2 = dataset_factory(small_dataset_file, shuffle=False)
+ dataset2.load_state_dict(state)
+
+ # First iter() call should start from epoch 2 (the checkpoint epoch)
+ it2 = iter(dataset2)
+ first_iter_samples = list(islice(it2, SMALL_DATASET_SIZE))
+ first_iter_epochs = [
+ metric.value
+ for sample in first_iter_samples
+ for metric in sample["metrics"]
+ if "num_epochs" in metric.key
+ ]
+ assert all(
+ epoch == 2 for epoch in first_iter_epochs
+ ), f"First iter() should start at checkpoint epoch 2, got {set(first_iter_epochs)}"
+
+ # Consume one more epoch from the same iterator (now at epoch 3)
+ second_epoch_samples = list(islice(it2, SMALL_DATASET_SIZE))
+ second_epoch_epochs = [
+ metric.value
+ for sample in second_epoch_samples
+ for metric in sample["metrics"]
+ if "num_epochs" in metric.key
+ ]
+ assert all(
+ epoch == 3 for epoch in second_epoch_epochs
+ ), f"Second epoch should be 3, got {set(second_epoch_epochs)}"
+
+ # Call iter() again - it should restart from epoch 2, not continue from 4
+ it3 = iter(dataset2)
+ new_iter_samples = list(islice(it3, SMALL_DATASET_SIZE))
+ new_iter_epochs = [
+ metric.value
+ for sample in new_iter_samples
+ for metric in sample["metrics"]
+ if "num_epochs" in metric.key
+ ]
+ assert all(
+ epoch == 2 for epoch in new_iter_epochs
+ ), f"New iter() should restart from checkpoint epoch 2, got {set(new_iter_epochs)}"
+
class TestDistributedHfIterableDataset(FSDPTest):
+ """Test HfIterableDataset with 2-GPU distributed setup."""
+
@property
def world_size(self) -> int:
return 2
@@ -291,30 +355,20 @@ def test_distributed_epoch_boundary_checkpointing(self):
"""
rank = dist.get_rank()
- # Create shared temp directory (only rank 0 creates it)
- if rank == 0:
- temp_dir = tempfile.mkdtemp(prefix="epoch_test_")
- else:
- temp_dir = ""
-
- # Broadcast temp directory path to all ranks
- temp_dir_list = [temp_dir]
- dist.broadcast_object_list(temp_dir_list, src=0)
- temp_dir = temp_dir_list[0]
+ # Each rank creates its own local temp dir and files
+ temp_dir = tempfile.mkdtemp(prefix=f"epoch_test_rank{rank}_")
tmp_path = Path(temp_dir)
try:
medium_dataset_file = tmp_path / "medium_data.json"
- # Only rank 0 creates the data file, all ranks read from it
- if rank == 0:
- create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE)
- dist.barrier() # Wait for file creation
+ # Each rank creates its own file
+ create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE)
# Test multiple epoch boundaries
for num_epochs in [0.9, 1.0, 2.5]:
- def create_loader_and_aggregator():
+ def create_loader():
dataset = HfIterableDataset(
path="json",
data_files=str(medium_dataset_file),
@@ -322,7 +376,7 @@ def create_loader_and_aggregator():
dataset_name="epoch_test",
seed=SEED,
shuffle_buffer_size=0, # No shuffle for determinism
- metric_transform=DefaultTrainingMetricTransform(),
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=2,
)
loader = StatefulDataLoader(
@@ -331,10 +385,10 @@ def create_loader_and_aggregator():
collate_fn=collate_with_metrics,
num_workers=0,
)
- return loader, MetricsAggregator()
+ return loader
- loader1, aggregator1 = create_loader_and_aggregator()
- loader2, aggregator2 = create_loader_and_aggregator()
+ loader1 = create_loader()
+ loader2 = create_loader()
# Calculate steps to reach desired epoch boundary
samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size()
@@ -352,11 +406,9 @@ def create_loader_and_aggregator():
result = generate_ckpt(
loader1,
- aggregator1,
- steps_before,
- steps_after,
+ steps_before_checkpoint=steps_before,
+ steps_after_checkpoint=steps_after,
resume_dataloader=loader2,
- resume_aggregator=aggregator2,
)
# Verify deterministic resumption - critical for distributed training
@@ -375,10 +427,129 @@ def create_loader_and_aggregator():
num_epochs - 1e-9
) # -1e-9 so 1.0 epochs -> 0
assert (
- final_metrics["train_epoch_test/num_epochs"] == expected_epoch
+ final_metrics["dataset/epoch_test/num_epochs"] == expected_epoch
), f"Epoch count incorrect for {num_epochs} epochs test scenario"
finally:
- # Clean up temp directory (only rank 0)
+ shutil.rmtree(temp_dir)
+
+
+class TestDPShardingWithTP(FSDPTest):
+ """Test DP sharding with TP replication (4-GPU setup)."""
+
+ @property
+ def world_size(self) -> int:
+ return 4
+
+ @gpu_test(gpu_count=4)
+ def test_dp_sharding_with_tp_replication(self):
+ """Verify DP sharding works correctly with TP/CP replication.
+
+ This is a CRITICAL test that validates the core bug fix:
+ - Previously: Each rank got different batches (incorrect)
+ - Now: TP/CP ranks within same DP group get identical batches (correct)
+
+ Setup: DP=2, TP=2 (4 GPUs total)
+ - DP group 0: ranks [0, 1] - should see SAME batches (TP replication)
+ - DP group 1: ranks [2, 3] - should see SAME batches (TP replication)
+ - DP group 0 vs 1: should see DIFFERENT batches (DP sharding)
+
+ Mesh structure:
+ - TP rank 0 DP replicas: [0, 2] - shard across these
+ - TP rank 1 DP replicas: [1, 3] - shard across these
+ """
+ import hashlib
+
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ temp_dir = tempfile.mkdtemp(prefix=f"dp_tp_test_rank{rank}_")
+
+ try:
+ data_file = Path(temp_dir) / "data.json"
+ # Create dataset with enough samples for clear sharding
+ # 40 samples / 2 DP groups = 20 samples per DP group
+ create_test_json_file(data_file, MEDIUM_DATASET_SIZE, offset=0)
+
+ # Create DP mesh for sharding
+ # Key insight: Create groups across DP replicas for each TP rank
+ # TP rank = rank % 2, so:
+ # - TP rank 0: ranks [0, 2] (one from each DP group)
+ # - TP rank 1: ranks [1, 3] (one from each DP group)
+ tp_rank = rank % 2
+ tp_world_size = 2
+ dp_world_size = world_size // tp_world_size
+
+ # Create DP groups for each TP rank
+ dp_groups = []
+ for tp_r in range(tp_world_size):
+ # Ranks for this TP rank across DP groups
+ ranks = [tp_r + i * tp_world_size for i in range(dp_world_size)]
+ group = dist.new_group(ranks=ranks)
+ dp_groups.append(group)
+
+ dp_mesh = dp_groups[tp_rank]
+
+ # - Rank 0 (tp_rank=0) uses group [0, 2], gets rank=0 → shard 0
+ # - Rank 1 (tp_rank=1) uses group [1, 3], gets rank=0 → shard 0
+ # - Rank 2 (tp_rank=0) uses group [0, 2], gets rank=1 → shard 1
+ # - Rank 3 (tp_rank=1) uses group [1, 3], gets rank=1 → shard 1
+
+ dataset = HfIterableDataset(
+ path="json",
+ data_files=str(data_file),
+ split="train",
+ dataset_name="dp_tp_test",
+ shuffle_buffer_size=0,
+ metric_transform=DefaultDatasetMetricTransform(),
+ num_shards_per_rank=2,
+ dp_mesh=dp_mesh, # CRITICAL: Pass dp_mesh for correct sharding
+ )
+
+ dataloader = StatefulDataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ collate_fn=collate_with_metrics,
+ num_workers=0,
+ )
+
+ # Collect batches and compute hashes
+ batches = list(islice(iter(dataloader), 5))
+ batch_hashes = []
+ for batch in batches:
+ # Hash the batch IDs to verify identity/difference
+ batch_ids = batch["id"].cpu().tolist()
+ batch_hash = hashlib.md5(str(batch_ids).encode()).hexdigest()
+ batch_hashes.append(batch_hash)
+
+ # Gather hashes from all ranks for comparison
+ gathered_hashes = [None] * world_size
+ dist.all_gather_object(gathered_hashes, batch_hashes)
+
if rank == 0:
- shutil.rmtree(temp_dir)
+ # Verify TP replication within DP groups
+ # Ranks 0 and 1 should have identical hashes (same DP group)
+ assert gathered_hashes[0] == gathered_hashes[1], (
+ f"Ranks 0 and 1 (same DP group) should see identical batches!\n"
+ f"Rank 0 hashes: {gathered_hashes[0][:3]}...\n"
+ f"Rank 1 hashes: {gathered_hashes[1][:3]}..."
+ )
+
+ # Ranks 2 and 3 should have identical hashes (same DP group)
+ assert gathered_hashes[2] == gathered_hashes[3], (
+ f"Ranks 2 and 3 (same DP group) should see identical batches!\n"
+ f"Rank 2 hashes: {gathered_hashes[2][:3]}...\n"
+ f"Rank 3 hashes: {gathered_hashes[3][:3]}..."
+ )
+
+ # Verify DP sharding across groups
+ # Ranks 0/1 should see DIFFERENT batches from ranks 2/3
+ assert gathered_hashes[0] != gathered_hashes[2], (
+ f"Ranks 0 and 2 (different DP groups) should see different batches!\n"
+ f"DP group 0 hashes: {gathered_hashes[0][:3]}...\n"
+ f"DP group 1 hashes: {gathered_hashes[2][:3]}..."
+ )
+
+ dist.barrier()
+
+ finally:
+ shutil.rmtree(temp_dir)
diff --git a/tests/unit_tests/datasets/test_interleaved.py b/tests/unit_tests/datasets/test_interleaved.py
index 0073b905e..b4cee0ce3 100644
--- a/tests/unit_tests/datasets/test_interleaved.py
+++ b/tests/unit_tests/datasets/test_interleaved.py
@@ -28,13 +28,13 @@
import torch
import torch.distributed as dist
-
-from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator
from forge.data.datasets import HfIterableDataset, InterleavedDataset
-from torch.testing._internal.common_fsdp import FSDPTest
-from torchdata.stateful_dataloader import StatefulDataLoader
+
+from forge.data.metric_transform import DefaultDatasetMetricTransform
from tests.test_utils import gpu_test
+from torch.testing._internal.common_fsdp import FSDPTest
+from torchdata.stateful_dataloader import StatefulDataLoader
# Import test utilities
from .test_iterable_utils import collate_with_metrics, generate_ckpt
@@ -114,7 +114,7 @@ def _create_dataset(
dataset_name=dataset_name,
seed=SEED,
shuffle_buffer_size=10 if shuffle else 0,
- metric_transform=DefaultTrainingMetricTransform(),
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=2,
**kwargs,
)
@@ -299,37 +299,47 @@ def test_metrics_aggregation(
[child_interleaved, ds3], seed=SEED, dataset_name="parent"
)
- aggregator = MetricsAggregator()
+ # Collect metrics
+ collected_metrics = []
# Process some samples
total_samples = 300
for sample in islice(iter(parent_interleaved), total_samples):
- aggregator.update(sample["metrics"])
-
- metrics = aggregator.get_metrics_for_logging(prefix="train")
-
- # Should have metrics from all three datasets, with flat keys
- assert "train_ds1/samples_seen" in metrics
- assert "train_ds2/samples_seen" in metrics
- assert "train_ds3/samples_seen" in metrics
+ if "metrics" in sample:
+ collected_metrics.extend(sample["metrics"])
+
+ # Count metrics by dataset name
+ ds1_samples_processed = sum(
+ 1
+ for m in collected_metrics
+ if hasattr(m, "key") and "dataset/ds1/samples_processed" in m.key
+ )
+ ds2_samples_processed = sum(
+ 1
+ for m in collected_metrics
+ if hasattr(m, "key") and "dataset/ds2/samples_processed" in m.key
+ )
+ ds3_samples_processed = sum(
+ 1
+ for m in collected_metrics
+ if hasattr(m, "key") and "dataset/ds3/samples_processed" in m.key
+ )
# All datasets should have contributed samples
- assert metrics["train_ds1/samples_seen"] > 0
- assert metrics["train_ds2/samples_seen"] > 0
- assert metrics["train_ds3/samples_seen"] > 0
+ assert ds1_samples_processed > 0, "ds1 should have contributed samples"
+ assert ds2_samples_processed > 0, "ds2 should have contributed samples"
+ assert ds3_samples_processed > 0, "ds3 should have contributed samples"
# Total samples should equal what we processed
calculated_total_samples = (
- metrics["train_ds1/samples_seen"]
- + metrics["train_ds2/samples_seen"]
- + metrics["train_ds3/samples_seen"]
+ ds1_samples_processed + ds2_samples_processed + ds3_samples_processed
)
assert calculated_total_samples == total_samples
# Test that ratios are approximately correct based on nested weighting
- ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples
- ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples
- ds3_ratio = metrics["train_ds3/samples_seen"] / total_samples
+ ds1_ratio = ds1_samples_processed / total_samples
+ ds2_ratio = ds2_samples_processed / total_samples
+ ds3_ratio = ds3_samples_processed / total_samples
# Expected ratios based on nested weighting:
# Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0
@@ -377,32 +387,30 @@ def create_interleaved():
loader1 = StatefulDataLoader(
interleaved1, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics
)
- aggregator1 = MetricsAggregator()
# Resumed run
interleaved2 = create_interleaved()
loader2 = StatefulDataLoader(
interleaved2, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics
)
- aggregator2 = MetricsAggregator()
result = generate_ckpt(
loader1,
- aggregator1,
steps_before_checkpoint=10,
steps_after_checkpoint=20,
resume_dataloader=loader2,
- resume_aggregator=aggregator2,
)
+ # Verify checkpointing and resumption work correctly
+ # After loading a checkpoint, training should continue identically
orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]]
resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]]
assert (
orig_post_ids == resumed_ids
), "Resumed batches should be identical for deterministic run"
assert (
- result["final_metrics"] == result["resumed_metrics"]
- ), "Final metrics should match"
+ result["post_checkpoint_metrics"] == result["resumed_metrics"]
+ ), "Resumed training should produce same metrics as original training"
# Test sampling log functionality
# Check that sampling log contains tuples of (iteration_count, dataset_name)
@@ -476,43 +484,36 @@ def test_distributed_interleaved_checkpointing(self):
"""
rank = dist.get_rank()
- # Create shared temp directory (only rank 0 creates it)
- if rank == 0:
- temp_dir = tempfile.mkdtemp(prefix="interleaved_test_")
- else:
- temp_dir = None
-
- # Broadcast temp directory to all ranks
- temp_dir_list = [temp_dir] if temp_dir is not None else [""]
- dist.broadcast_object_list(temp_dir_list, src=0)
- temp_dir = temp_dir_list[0]
+ # Each rank creates its own local temp dir and files (no broadcast/barrier needed for creation)
+ temp_dir = tempfile.mkdtemp(prefix=f"interleaved_test_rank{rank}_")
tmp_path = Path(temp_dir)
try:
-
- def create_dataset():
- file1 = tmp_path / "ds1.json"
- file2 = tmp_path / "ds2.json"
- file3 = tmp_path / "ds3.json"
-
- # Only rank 0 creates the data files
- if rank == 0:
- create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22
- create_test_json_file(
- file2, MEDIUM_DATASET_SIZE, offset=100
- ) # IDs 100-134
- create_test_json_file(
- file3, LARGE_DATASET_SIZE, offset=1000
- ) # IDs 1000-1046
- dist.barrier() # Wait for file creation
-
+ # ============================================
+ # SETUP: Each rank creates its own test files
+ # ============================================
+ file1 = tmp_path / "ds1.json"
+ file2 = tmp_path / "ds2.json"
+ file3 = tmp_path / "ds3.json"
+
+ create_test_json_file(file1, SMALL_DATASET_SIZE, offset=0)
+ create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100)
+ create_test_json_file(file3, LARGE_DATASET_SIZE, offset=1000)
+
+ # No barrier needed since files are local to each rank
+
+ # ============================================
+ # TEST LOGIC: Functions that use the files
+ # ============================================
+ def create_dataset() -> InterleavedDataset:
+ """Create interleaved dataset from local files."""
ds1 = HfIterableDataset(
path="json",
data_files=str(file1),
split="train",
dataset_name="ds1",
- shuffle_buffer_size=0, # No shuffle for determinism
- metric_transform=DefaultTrainingMetricTransform(),
+ shuffle_buffer_size=0,
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=2,
weight=0.3,
)
@@ -521,8 +522,8 @@ def create_dataset():
data_files=str(file2),
split="train",
dataset_name="ds2",
- shuffle_buffer_size=0, # No shuffle for determinism
- metric_transform=DefaultTrainingMetricTransform(),
+ shuffle_buffer_size=0,
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=2,
weight=0.7,
)
@@ -531,8 +532,8 @@ def create_dataset():
data_files=str(file3),
split="train",
dataset_name="ds3",
- shuffle_buffer_size=0, # No shuffle for determinism
- metric_transform=DefaultTrainingMetricTransform(),
+ shuffle_buffer_size=0,
+ metric_transform=DefaultDatasetMetricTransform(),
num_shards_per_rank=2,
weight=1.0,
)
@@ -552,19 +553,17 @@ def create_dataloader(dataset):
num_workers=0, # Avoid multiprocessing in distributed tests
collate_fn=collate_with_metrics,
)
- return loader, MetricsAggregator()
+ return loader
# Run checkpointing test with small number of steps
- loader1, aggregator1 = create_dataloader(create_dataset())
- loader2, aggregator2 = create_dataloader(create_dataset())
+ loader1 = create_dataloader(create_dataset())
+ loader2 = create_dataloader(create_dataset())
result = generate_ckpt(
loader1,
- aggregator1,
- 3,
- 3, # 3 steps before, 3 steps after checkpoint
+ steps_before_checkpoint=3,
+ steps_after_checkpoint=3,
resume_dataloader=loader2,
- resume_aggregator=aggregator2,
)
# Verify deterministic resumption
@@ -577,8 +576,8 @@ def create_dataloader(dataset):
f"This indicates sampling state is not properly preserved."
)
assert (
- result["final_metrics"] == result["resumed_metrics"]
- ), "Final metrics don't match resumed metrics - aggregator state issue"
+ result["post_checkpoint_metrics"] == result["resumed_metrics"]
+ ), "Resumed training should produce same metrics as original training"
# Verify sampling ratio is approximately maintained for nested structure
all_ids = []
@@ -621,6 +620,5 @@ def create_dataloader(dataset):
), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}"
finally:
- # Clean up temp directory (only rank 0)
- if rank == 0:
- shutil.rmtree(temp_dir)
+ # Each rank cleans its own temp dir
+ shutil.rmtree(temp_dir)
diff --git a/tests/unit_tests/datasets/test_iterable_utils.py b/tests/unit_tests/datasets/test_iterable_utils.py
index cdeced7c7..0c6d26fe3 100644
--- a/tests/unit_tests/datasets/test_iterable_utils.py
+++ b/tests/unit_tests/datasets/test_iterable_utils.py
@@ -7,92 +7,91 @@
from typing import Any, Optional
import torch
-from forge.data.dataset_metrics import MetricsAggregator
-
from torch.utils.data import DataLoader
-def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]:
- """Simple collate that extracts metrics and pads tokens."""
- all_metrics = []
- clean_batch = []
+def collate_with_metrics(batch):
+ """
+ Simple collate function that preserves metrics for validation.
+ Collects metrics from all samples in the batch and aggregates them.
+
+ Uses a simple collation that doesn't enforce same sizes for lists/tokens.
+ """
+ # Collect metrics from all samples
+ batch_metrics = []
for sample in batch:
if "metrics" in sample:
- all_metrics.extend(sample.pop("metrics"))
- clean_batch.append(sample)
-
- if not clean_batch:
- return {"metrics": all_metrics}
-
- # Simple padding for tokens
- ids = torch.tensor([item["id"] for item in clean_batch])
- tokens = torch.nn.utils.rnn.pad_sequence(
- [torch.tensor(item["tokens"]) for item in clean_batch],
- batch_first=True,
- padding_value=-1, # Use -1 for padding to distinguish from valid IDs
- )
- collated = {
- "id": ids,
- "tokens": tokens,
- }
-
- # Add text field for non-tensor data
- if "text" in clean_batch[0]:
- collated["text"] = [item["text"] for item in clean_batch]
+ batch_metrics.extend(sample.pop("metrics"))
+
+ # Simple collation that handles variable-length sequences
+ collated = {}
+ if batch:
+ for key in batch[0].keys():
+ values = [sample[key] for sample in batch]
+ if key == "tokens" or key == "labels":
+ # Keep as list of lists for variable length sequences
+ collated[key] = values
+ else:
+ # Use default collation for scalars
+ collated[key] = torch.utils.data.default_collate(values)
+
+ # Add batch-level metrics key for downstream processing
+ if batch_metrics:
+ collated["metrics"] = batch_metrics
- collated["metrics"] = all_metrics
return collated
def generate_ckpt(
dataloader: DataLoader,
- aggregator: MetricsAggregator,
steps_before_checkpoint: int,
steps_after_checkpoint: int,
resume_dataloader: Optional[DataLoader] = None,
- resume_aggregator: Optional[MetricsAggregator] = None,
) -> dict[str, Any]:
"""
Generates a checkpoint by running through data and saving checkpoint mid-stream.
- Optionally, a second dataloader and aggregator can be given to resume from ckpt
+ Optionally, a second dataloader can be given to resume from checkpoint
and run steps_after_checkpoint to match the first one.
+ Collects and aggregates metrics for test validation purposes.
+
Args:
dataloader (DataLoader): The dataloader to test
- aggregator (MetricsAggregator): The metrics aggregator to use
steps_before_checkpoint (int): Number of steps to run before saving checkpoint
steps_after_checkpoint (int): Number of steps to run after checkpoint
resume_dataloader (Optional[DataLoader]): Optional new dataloader to test resuming.
If None, returns empty resumed_batches.
- resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming.
- If None, returns empty resumed_metrics.
Returns:
- dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs.
+ dict[str, Any]: Dict with batches and aggregated metrics for validation.
"""
iterator = iter(dataloader)
- # Collect batches before and after checkpoint
+ # Collect batches and metrics before and after checkpoint
batches = []
+ all_metrics = [] # All metrics collected during the run
+ checkpoint_metrics = [] # Metrics collected only up to checkpoint
checkpoint_state = None
- metrics_at_checkpoint = {}
total_steps = steps_before_checkpoint + steps_after_checkpoint
for idx, batch in enumerate(iterator):
batches.append(batch)
- # Process metrics
+ # Collect metrics for test validation
if "metrics" in batch:
- aggregator.update(batch.pop("metrics"))
+ batch_metrics = batch.pop("metrics")
+ all_metrics.extend(batch_metrics)
+
+ # If we haven't reached checkpoint yet, also add to checkpoint metrics
+ if idx < steps_before_checkpoint:
+ checkpoint_metrics.extend(batch_metrics)
# Save checkpoint state after steps_before_checkpoint
if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based
checkpoint_state = {
"loader": dataloader.state_dict(),
- "aggregator": aggregator.state_dict(),
}
- metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train")
# Stop after total steps
if idx == total_steps - 1:
@@ -102,43 +101,56 @@ def generate_ckpt(
pre_checkpoint_batches = batches[:steps_before_checkpoint]
post_checkpoint_batches = batches[steps_before_checkpoint:]
- # Resume with new instances if provided
+ # Compute metrics for post-checkpoint batches only
+ post_checkpoint_metrics = all_metrics[len(checkpoint_metrics) :]
+
+ # Resume with new instance if provided
resumed_batches = []
- resumed_metrics = {}
-
- if (
- resume_dataloader is not None
- and resume_aggregator is not None
- and checkpoint_state is not None
- ):
- # Test resuming with new instances
+ resumed_metrics = []
+
+ if resume_dataloader is not None and checkpoint_state is not None:
+ # Test resuming with new instance
resume_dataloader.load_state_dict(checkpoint_state["loader"])
- resume_aggregator.load_state_dict(checkpoint_state["aggregator"])
resume_iterator = iter(resume_dataloader)
# Collect only the post-checkpoint batches when resuming
for idx, batch in enumerate(resume_iterator):
resumed_batches.append(batch)
- # Process metrics
+ # Collect metrics from resumed batches
if "metrics" in batch:
- resume_aggregator.update(batch.pop("metrics"))
+ batch_metrics = batch.pop("metrics")
+ resumed_metrics.extend(batch_metrics)
# Stop after steps_after_checkpoint
if idx == steps_after_checkpoint - 1:
break
- resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train")
-
return {
# Original run
"pre_checkpoint_batches": pre_checkpoint_batches,
"post_checkpoint_batches": post_checkpoint_batches,
- "metrics_at_checkpoint": metrics_at_checkpoint,
- "final_metrics": aggregator.get_metrics_for_logging(prefix="train"),
+ "metrics_at_checkpoint": aggregate_metrics(checkpoint_metrics),
+ "post_checkpoint_metrics": aggregate_metrics(post_checkpoint_metrics),
+ "final_metrics": aggregate_metrics(all_metrics),
# Resumed run
"resumed_batches": resumed_batches,
- "resumed_metrics": resumed_metrics,
+ "resumed_metrics": aggregate_metrics(resumed_metrics),
# Internal state for loading - only if someone needs to manually load
"_checkpoint_state": checkpoint_state,
}
+
+
+def aggregate_metrics(metrics_list: list) -> dict[str, Any]:
+ if not metrics_list:
+ return {}
+
+ accumulators = {}
+
+ for metric in metrics_list:
+ key = metric.key
+ if key not in accumulators:
+ accumulators[key] = metric.reduction.accumulator_class(metric.reduction)
+ accumulators[key].append(metric.value)
+
+ return {key: acc.get_value() for key, acc in accumulators.items()}
diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py
index 352fbf703..2a951b27d 100644
--- a/tests/unit_tests/datasets/test_packed.py
+++ b/tests/unit_tests/datasets/test_packed.py
@@ -13,8 +13,8 @@
import pytest
import torch
-from forge.data.collate import collate_packed
-from forge.data.dataset_metrics import MetricsAggregator
+from forge.data import CROSS_ENTROPY_IGNORE_IDX
+from forge.data.collate import collate_packed, collate_padded
from forge.data.datasets import HfIterableDataset
from forge.data.datasets.packed import (
_SUPPORTS_FLEX_ATTENTION,
@@ -914,7 +914,7 @@ def test_checkpoint_and_resume(self, dataset_factory):
batch_size = 1
# Setup dataset factory
- def create_loader_and_aggregator():
+ def create_loader():
dataset = dataset_factory(samples)
packer = TextPacker(padding_idx=999, ignore_idx=-100)
packed_dataset = PackedDataset(
@@ -931,11 +931,10 @@ def create_loader_and_aggregator():
loader = StatefulDataLoader(
packed_dataset, batch_size=batch_size, collate_fn=collate_fn
)
- aggregator = MetricsAggregator()
- return loader, aggregator
+ return loader
- loader1, aggregator1 = create_loader_and_aggregator()
- loader2, aggregator2 = create_loader_and_aggregator()
+ loader1 = create_loader()
+ loader2 = create_loader()
steps_before_checkpoint = 2
steps_after_checkpoint = 2
@@ -943,13 +942,236 @@ def create_loader_and_aggregator():
# Generate checkpoint and resume
result = generate_ckpt(
loader1,
- aggregator1,
steps_before_checkpoint=steps_before_checkpoint,
steps_after_checkpoint=steps_after_checkpoint,
resume_dataloader=loader2,
- resume_aggregator=aggregator2,
)
# Verify that checkpointing and resumption work
assert len(result["post_checkpoint_batches"]) == steps_after_checkpoint
assert len(result["resumed_batches"]) == steps_after_checkpoint
+
+ def test_iter_restart_determinism(self, dataset_factory):
+ """Test that calling iter() multiple times produces deterministic results.
+
+ This is critical for evaluation: each eval run should start from the
+ same state (epoch 0, step 0) regardless of previous iterations.
+ """
+ samples = [
+ {"tokens": [0] * 3},
+ {"tokens": [1] * 2},
+ {"tokens": [2] * 4},
+ ]
+ target_tokens_per_pack = 6
+
+ # Create packed dataset
+ dataset = dataset_factory(samples)
+ packer = TextPacker(padding_idx=999, ignore_idx=-100)
+ packed_dataset = PackedDataset(
+ dataset=dataset,
+ packer=packer,
+ target_tokens_per_pack=target_tokens_per_pack,
+ buffer_size=1,
+ )
+
+ # First iteration - get first 2 packs
+ iter1 = iter(packed_dataset)
+ packs_iter1 = list(islice(iter1, 2))
+
+ # Second iteration - should get same first 2 packs
+ iter2 = iter(packed_dataset)
+ packs_iter2 = list(islice(iter2, 2))
+
+ # Verify both iterations produce identical packs
+ assert len(packs_iter1) == len(packs_iter2) == 2
+
+ for i, (pack1, pack2) in enumerate(zip(packs_iter1, packs_iter2)):
+ torch.testing.assert_close(
+ pack1["tokens"],
+ pack2["tokens"],
+ msg=f"Pack {i}: tokens mismatch between iterations",
+ )
+ torch.testing.assert_close(
+ pack1["document_ids"],
+ pack2["document_ids"],
+ msg=f"Pack {i}: document_ids mismatch between iterations",
+ )
+
+
+class TestCollatePadded:
+ """Test collate_padded function"""
+
+ def test_empty_batch(self):
+ """Test collating an empty batch"""
+ result = collate_padded([])
+ assert result == {}
+
+ def test_single_sample(self):
+ """Test collating a single sample"""
+ batch = [
+ {
+ "tokens": torch.tensor([1, 2, 3]),
+ "labels": torch.tensor([4, 5, 6]),
+ }
+ ]
+ result = collate_padded(batch)
+
+ assert result["tokens"].shape == (1, 3)
+ assert result["labels"].shape == (1, 3)
+ torch.testing.assert_close(result["tokens"], torch.tensor([[1, 2, 3]]))
+ torch.testing.assert_close(result["labels"], torch.tensor([[4, 5, 6]]))
+
+ def test_equal_length_samples(self):
+ """Test collating samples with equal lengths"""
+ batch = [
+ {
+ "tokens": torch.tensor([1, 2, 3]),
+ "labels": torch.tensor([4, 5, 6]),
+ },
+ {
+ "tokens": torch.tensor([7, 8, 9]),
+ "labels": torch.tensor([10, 11, 12]),
+ },
+ ]
+ result = collate_padded(batch)
+
+ assert result["tokens"].shape == (2, 3)
+ assert result["labels"].shape == (2, 3)
+ torch.testing.assert_close(
+ result["tokens"], torch.tensor([[1, 2, 3], [7, 8, 9]])
+ )
+ torch.testing.assert_close(
+ result["labels"], torch.tensor([[4, 5, 6], [10, 11, 12]])
+ )
+
+ def test_padding_to_longest(self):
+ """Test padding shorter sequences to the longest in batch"""
+ batch = [
+ {
+ "tokens": torch.tensor([1, 2]),
+ "labels": torch.tensor([3, 4]),
+ },
+ {
+ "tokens": torch.tensor([5, 6, 7, 8]),
+ "labels": torch.tensor([9, 10, 11, 12]),
+ },
+ {
+ "tokens": torch.tensor([13, 14, 15]),
+ "labels": torch.tensor([16, 17, 18]),
+ },
+ ]
+ result = collate_padded(batch)
+
+ # All should be padded to length 4 (longest)
+ assert result["tokens"].shape == (3, 4)
+ assert result["labels"].shape == (3, 4)
+
+ # Check tokens padding (padded with 0)
+ torch.testing.assert_close(
+ result["tokens"],
+ torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8], [13, 14, 15, 0]]),
+ )
+
+ # Check labels padding (padded with CROSS_ENTROPY_IGNORE_IDX)
+ torch.testing.assert_close(
+ result["labels"],
+ torch.tensor(
+ [
+ [3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
+ [9, 10, 11, 12],
+ [16, 17, 18, CROSS_ENTROPY_IGNORE_IDX],
+ ]
+ ),
+ )
+
+ def test_non_tensor_fields_preserved(self):
+ """Test that non-tensor fields are collected correctly"""
+ batch = [
+ {
+ "tokens": torch.tensor([1, 2]),
+ "labels": torch.tensor([3, 4]),
+ "metadata": "sample1",
+ },
+ {
+ "tokens": torch.tensor([5, 6, 7]),
+ "labels": torch.tensor([8, 9, 10]),
+ "metadata": "sample2",
+ },
+ ]
+ result = collate_padded(batch)
+
+ assert "metadata" in result
+ assert result["metadata"] == ["sample1", "sample2"]
+
+ def test_metrics_flattened(self):
+ """Test that metrics lists are flattened"""
+ batch = [
+ {
+ "tokens": torch.tensor([1, 2]),
+ "labels": torch.tensor([3, 4]),
+ "metrics": [
+ type("Metric", (), {"key": "loss", "value": 1.0})(),
+ type("Metric", (), {"key": "acc", "value": 0.9})(),
+ ],
+ },
+ {
+ "tokens": torch.tensor([5, 6, 7]),
+ "labels": torch.tensor([8, 9, 10]),
+ "metrics": [type("Metric", (), {"key": "loss", "value": 2.0})()],
+ },
+ ]
+ result = collate_padded(batch)
+
+ assert "metrics" in result
+ # Should be flattened from [[metric1, metric2], [metric3]] to [metric1, metric2, metric3]
+ assert len(result["metrics"]) == 3
+
+ def test_different_keys_error(self):
+ """Test that different keys across samples raises ValueError"""
+ batch = [
+ {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])},
+ {"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])},
+ ]
+
+ with pytest.raises(ValueError, match="All samples must have the same keys"):
+ collate_padded(batch)
+
+ def test_generic_tensor_handling(self):
+ """Test that any tensor field gets padded correctly"""
+ batch = [
+ {
+ "tokens": torch.tensor([1, 2]),
+ "labels": torch.tensor([3, 4]),
+ "custom_tensor": torch.tensor([100, 200, 300]),
+ },
+ {
+ "tokens": torch.tensor([5, 6, 7, 8]),
+ "labels": torch.tensor([9, 10, 11, 12]),
+ "custom_tensor": torch.tensor([400]),
+ },
+ ]
+ result = collate_padded(batch)
+
+ # Tokens padded to length 4
+ assert result["tokens"].shape == (2, 4)
+ torch.testing.assert_close(
+ result["tokens"], torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8]])
+ )
+
+ # Labels padded to length 4 with CROSS_ENTROPY_IGNORE_IDX
+ assert result["labels"].shape == (2, 4)
+ torch.testing.assert_close(
+ result["labels"],
+ torch.tensor(
+ [
+ [3, 4, CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX],
+ [9, 10, 11, 12],
+ ]
+ ),
+ )
+
+ # Custom tensor padded to length 3 with 0
+ assert result["custom_tensor"].shape == (2, 3)
+ torch.testing.assert_close(
+ result["custom_tensor"], torch.tensor([[100, 200, 300], [400, 0, 0]])
+ )
diff --git a/tests/unit_tests/datasets/test_stop_after_one_epoch.py b/tests/unit_tests/datasets/test_stop_after_one_epoch.py
new file mode 100644
index 000000000..45cf035af
--- /dev/null
+++ b/tests/unit_tests/datasets/test_stop_after_one_epoch.py
@@ -0,0 +1,184 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Tests for StopAfterOneEpoch iterator and extract_epoch_from_batch helper."""
+from pathlib import Path
+
+import pytest
+import torch
+import torch.distributed as dist
+from forge.data.datasets import HfIterableDataset
+
+from forge.data.utils import extract_epoch_from_batch, StopAfterOneEpoch
+from forge.observability.metrics import Metric, Reduce
+
+from tests.test_utils import gpu_test
+from torch.testing._internal.common_fsdp import FSDPTest
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+
+def create_test_json_file(path: Path, num_samples: int) -> None:
+ """Create test data file with simple samples."""
+ with open(path, "w") as f:
+ for i in range(num_samples):
+ f.write(f'{{"id": {i}, "tokens": [{i}, {i + 1}]}}\n')
+
+
+def simple_collate(batch):
+ """Simple collate function that mimics collate_packed behavior.
+
+ Stacks tensors, extends metrics list, keeps other fields as lists.
+ """
+ collated = {}
+ for key in batch[0].keys():
+ if isinstance(batch[0][key], torch.Tensor):
+ collated[key] = torch.stack([sample[key] for sample in batch], dim=0)
+ elif key == "metrics":
+ # Extend all metrics into a single list
+ collated[key] = []
+ for sample in batch:
+ collated[key].extend(sample[key])
+ else:
+ collated[key] = [sample[key] for sample in batch]
+ return collated
+
+
+class TestExtractEpochFromBatch:
+ """Test extract_epoch_from_batch helper function."""
+
+ def test_extract_epoch_from_batch_success(self):
+ """Test extracting epoch from valid batch with metrics."""
+ batch = {
+ "tokens": torch.tensor([1, 2, 3]),
+ "metrics": [
+ Metric(key="dataset/test/num_epochs", value=2, reduction=Reduce.MAX),
+ Metric(
+ key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN
+ ),
+ ],
+ }
+ epoch = extract_epoch_from_batch(batch)
+ assert epoch == 2
+
+ def test_extract_epoch_missing_metrics_field(self):
+ """Test error when batch has no 'metrics' field."""
+ batch = {"tokens": torch.tensor([1, 2, 3])}
+ with pytest.raises(ValueError, match="Batch missing 'metrics' field"):
+ extract_epoch_from_batch(batch)
+
+ def test_extract_epoch_no_num_epochs_metric(self):
+ """Test error when no num_epochs metric found."""
+ batch = {
+ "metrics": [
+ Metric(
+ key="dataset/test/other_metric", value=42, reduction=Reduce.MEAN
+ ),
+ ]
+ }
+ with pytest.raises(ValueError, match="No 'num_epochs' metric found"):
+ extract_epoch_from_batch(batch)
+
+
+class TestStopAfterOneEpochSingleProcess:
+ """Test StopAfterOneEpoch in single-process mode (no distributed)."""
+
+ def test_stop_after_one_epoch(self, tmp_path):
+ """Verify iterator stops after exactly one epoch completes."""
+ # Create small dataset (10 samples)
+ data_file = tmp_path / "data.json"
+ create_test_json_file(data_file, num_samples=10)
+
+ dataset = HfIterableDataset(
+ path="json",
+ data_files=str(data_file),
+ split="train",
+ shuffle_buffer_size=0,
+ num_shards_per_rank=1,
+ )
+
+ dataloader = StatefulDataLoader(
+ dataset, batch_size=2, collate_fn=simple_collate
+ )
+
+ # Wrap with StopAfterOneEpoch
+ batch_iter = StopAfterOneEpoch(
+ iter=iter(dataloader),
+ device=torch.device("cpu"),
+ dp_mesh=None,
+ )
+
+ # Collect all batches until StopIteration
+ batches = []
+ for batch in batch_iter:
+ batches.append(batch)
+ # Verify all batches are from epoch 0
+ epoch = extract_epoch_from_batch(batch)
+ assert epoch == 0, f"Expected epoch 0, got {epoch}"
+
+ # Should have consumed exactly one epoch (5 batches of size 2)
+ assert len(batches) == 5
+
+
+class TestStopAfterOneEpochDistributed(FSDPTest):
+ """Test StopAfterOneEpoch with distributed synchronization."""
+
+ @property
+ def world_size(self) -> int:
+ return 2
+
+ @gpu_test(gpu_count=2)
+ def test_epoch_sync_across_ranks(self):
+ """Verify all ranks stop when any rank detects epoch change."""
+ import shutil
+ import tempfile
+
+ rank = dist.get_rank()
+ temp_dir = tempfile.mkdtemp(prefix=f"stop_epoch_test_rank{rank}_")
+
+ try:
+ data_file = Path(temp_dir) / "data.json"
+ # Create dataset with 20 samples, split across 2 ranks (10 each)
+ create_test_json_file(data_file, num_samples=20)
+
+ dataset = HfIterableDataset(
+ path="json",
+ data_files=str(data_file),
+ split="train",
+ shuffle_buffer_size=0,
+ num_shards_per_rank=1,
+ )
+
+ dataloader = StatefulDataLoader(
+ dataset, batch_size=2, collate_fn=simple_collate
+ )
+
+ # Get DP process group (use global group for this test)
+ dp_mesh = dist.group.WORLD
+
+ batch_iter = StopAfterOneEpoch(
+ iter=iter(dataloader),
+ device=torch.device("cuda"),
+ dp_mesh=dp_mesh,
+ )
+
+ # Collect batches
+ batches = []
+ for batch in batch_iter:
+ batches.append(batch)
+ # All should be epoch 0
+ assert extract_epoch_from_batch(batch) == 0
+
+ # All ranks should have processed exactly one epoch
+ # Since dataset is split across ranks, each rank gets 10 samples = 5 batches
+ assert (
+ len(batches) == 5
+ ), f"Rank {rank} expected 5 batches, got {len(batches)}"
+
+ # Synchronize to ensure both ranks completed
+ dist.barrier()
+
+ finally:
+ shutil.rmtree(temp_dir)
diff --git a/src/forge/cli/__init__.py b/tests/unit_tests/examples/__init__.py
similarity index 100%
rename from src/forge/cli/__init__.py
rename to tests/unit_tests/examples/__init__.py
diff --git a/src/forge/envs/__init__.py b/tests/unit_tests/examples/gsm8k/__init__.py
similarity index 100%
rename from src/forge/envs/__init__.py
rename to tests/unit_tests/examples/gsm8k/__init__.py
diff --git a/tests/unit_tests/rl/test_math_reward.py b/tests/unit_tests/examples/gsm8k/test_math_reward.py
similarity index 99%
rename from tests/unit_tests/rl/test_math_reward.py
rename to tests/unit_tests/examples/gsm8k/test_math_reward.py
index 726b1173c..d5cddfb9c 100644
--- a/tests/unit_tests/rl/test_math_reward.py
+++ b/tests/unit_tests/examples/gsm8k/test_math_reward.py
@@ -6,7 +6,7 @@
import unittest
-from forge.data.rewards import MathReward
+from apps.grpo.grading import MathReward
class TestMathReward(unittest.TestCase):
diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/examples/gsm8k/test_thinking_reward.py
similarity index 93%
rename from tests/unit_tests/rl/test_thinking_reward.py
rename to tests/unit_tests/examples/gsm8k/test_thinking_reward.py
index b95823e9a..218fbd6f9 100644
--- a/tests/unit_tests/rl/test_thinking_reward.py
+++ b/tests/unit_tests/examples/gsm8k/test_thinking_reward.py
@@ -6,7 +6,7 @@
import unittest
-from forge.data.rewards import ThinkingReward
+from apps.grpo.grading import ThinkingReward
class TestThinkingReward(unittest.TestCase):
@@ -203,6 +203,19 @@ def test_call_very_long_thinking_block(self):
result = self.reward("prompt", f"{long_content}")
self.assertEqual(result, 1.0)
+ def test_custom_tag(self):
+ """Test that ThinkingReward uses the custom tag passed in."""
+ # Create reward with custom Japanese tag
+ custom_tag_reward = ThinkingReward(tag="思考")
+
+ # Response with custom tag should get full reward
+ result = custom_tag_reward("prompt", "<思考>This is my reasoning思考>")
+ self.assertEqual(result, 1.0)
+
+ # Response with default "think" tag should get no reward
+ result = custom_tag_reward("prompt", "This is my reasoning")
+ self.assertEqual(result, 0.0)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/unit_tests/losses/test_grpo_loss.py b/tests/unit_tests/losses/test_grpo_loss.py
index 8ffa6291f..6c9371427 100644
--- a/tests/unit_tests/losses/test_grpo_loss.py
+++ b/tests/unit_tests/losses/test_grpo_loss.py
@@ -35,7 +35,6 @@ def sample_data(self):
return logprobs, ref_logprobs, advantages, padding_mask
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_forward_basic(self, loss_fn, sample_data):
"""Test basic forward pass."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
@@ -48,7 +47,6 @@ def test_forward_basic(self, loss_fn, sample_data):
assert not torch.isnan(loss)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_output_shape(self, loss_fn):
"""Test output shape for different input sizes."""
for batch_size in [1, 3, 8]:
@@ -62,7 +60,6 @@ def test_output_shape(self, loss_fn):
assert loss.shape == torch.Size([])
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_gradient_flow(self, loss_fn, sample_data):
"""Test that gradients flow through logprobs."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
@@ -76,7 +73,6 @@ def test_gradient_flow(self, loss_fn, sample_data):
assert torch.isfinite(logprobs.grad).all()
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_no_gradient_to_ref_logprobs(self, loss_fn, sample_data):
"""Test that gradients don't flow to reference logprobs."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
@@ -89,7 +85,6 @@ def test_no_gradient_to_ref_logprobs(self, loss_fn, sample_data):
assert ref_logprobs.grad is not None
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_padding_mask_effect(self, loss_fn):
"""Test that padding mask correctly ignores padded tokens."""
batch_size, seq_len = 2, 4
@@ -111,7 +106,6 @@ def test_padding_mask_effect(self, loss_fn):
assert not torch.allclose(loss_full, loss_partial)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_beta_parameter_effect(self, sample_data):
"""Test that different beta values produce different losses."""
logprobs, ref_logprobs, advantages, padding_mask = sample_data
@@ -128,7 +122,6 @@ def test_beta_parameter_effect(self, sample_data):
assert not torch.allclose(loss_1, loss_2, atol=1e-6)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_zero_advantages(self, loss_fn):
"""Test behavior with zero advantages."""
batch_size, seq_len = 2, 4
@@ -144,7 +137,6 @@ def test_zero_advantages(self, loss_fn):
assert torch.isfinite(loss)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_identical_policies(self, loss_fn):
"""Test behavior when current and reference policies are identical."""
batch_size, seq_len = 2, 4
@@ -160,7 +152,6 @@ def test_identical_policies(self, loss_fn):
assert torch.isfinite(loss)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_extreme_values(self, loss_fn):
"""Test with extreme but valid values."""
batch_size, seq_len = 2, 3
@@ -179,7 +170,6 @@ def test_extreme_values(self, loss_fn):
assert not torch.isnan(loss)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_numerical_stability(self, loss_fn):
"""Test numerical stability with edge cases."""
batch_size, seq_len = 1, 2
@@ -195,7 +185,6 @@ def test_numerical_stability(self, loss_fn):
assert torch.isfinite(loss)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_all_masked_sequence(self, loss_fn):
"""Test behavior when entire sequence is masked."""
batch_size, seq_len = 1, 3
@@ -211,7 +200,6 @@ def test_all_masked_sequence(self, loss_fn):
assert torch.isfinite(loss)
@pytest.mark.timeout(10)
- @pytest.mark.asyncio
def test_mathematical_correctness(self, loss_fn):
"""Test mathematical correctness with simpler verification."""
# Test with known simple case
diff --git a/tests/unit_tests/data/__init__.py b/tests/unit_tests/observability/__init__.py
similarity index 100%
rename from tests/unit_tests/data/__init__.py
rename to tests/unit_tests/observability/__init__.py
diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py
new file mode 100644
index 000000000..627eaca77
--- /dev/null
+++ b/tests/unit_tests/observability/conftest.py
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Shared fixtures and mocks for observability unit tests."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from forge.observability.metrics import MetricCollector
+
+
+@pytest.fixture(autouse=True)
+def clear_metric_collector_singletons():
+ """Clear MetricCollector singletons before each test to avoid state leakage."""
+ MetricCollector._instances.clear()
+ yield
+ MetricCollector._instances.clear()
+
+
+@pytest.fixture(autouse=True)
+def clean_metrics_environment():
+ """Override the global mock_metrics_globally fixture to allow real metrics testing."""
+ import os
+
+ from forge.env import FORGE_DISABLE_METRICS
+
+ # Set default state for tests (metrics enabled)
+ if FORGE_DISABLE_METRICS.name in os.environ:
+ del os.environ[FORGE_DISABLE_METRICS.name]
+
+ yield
+
+
+@pytest.fixture
+def mock_rank():
+ """Mock current_rank function with configurable rank."""
+ with patch("forge.observability.metrics.current_rank") as mock:
+ rank_obj = MagicMock()
+ rank_obj.rank = 0
+ mock.return_value = rank_obj
+ yield mock
+
+
+@pytest.fixture
+def mock_actor_context():
+ """Mock Monarch actor context for testing actor name generation."""
+ with (
+ patch("forge.observability.metrics.context") as mock_context,
+ patch("forge.observability.metrics.current_rank") as mock_rank,
+ ):
+ # Setup mock context
+ ctx = MagicMock()
+ actor_instance = MagicMock()
+ actor_instance.actor_id = "_1rjutFUXQrEJ[0].TestActorConfigured[0]"
+ ctx.actor_instance = actor_instance
+ mock_context.return_value = ctx
+
+ # Setup mock rank
+ rank_obj = MagicMock()
+ rank_obj.rank = 0
+ mock_rank.return_value = rank_obj
+
+ yield {
+ "context": mock_context,
+ "rank": mock_rank,
+ "expected_name": "TestActor_0XQr_r0",
+ }
diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py
new file mode 100644
index 000000000..fd3c96687
--- /dev/null
+++ b/tests/unit_tests/observability/test_metric_actors.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Optimized unit tests for metric actors functionality."""
+
+from unittest.mock import patch
+
+import pytest
+
+from forge.observability.metric_actors import (
+ get_or_create_metric_logger,
+ GlobalLoggingActor,
+ LocalFetcherActor,
+)
+
+from forge.observability.metrics import LoggingMode
+from monarch.actor import this_host
+
+
+@pytest.fixture
+def global_logger():
+ """Create a GlobalLoggingActor for testing."""
+ p = this_host().spawn_procs(per_host={"cpus": 1})
+ return p.spawn("TestGlobalLogger", GlobalLoggingActor)
+
+
+@pytest.fixture
+def local_fetcher(global_logger):
+ """Create a LocalFetcherActor linked to global logger."""
+ p = this_host().spawn_procs(per_host={"cpus": 1})
+ return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger)
+
+
+class TestBasicOperations:
+ """Test basic operations for actors."""
+
+ @pytest.mark.asyncio
+ async def test_local_fetcher_flush(self, local_fetcher):
+ """Test LocalFetcherActor flush operations."""
+ result_with_state = await local_fetcher.flush.call_one(
+ global_step=1, return_state=True
+ )
+ assert result_with_state == {}
+
+ result_without_state = await local_fetcher.flush.call_one(
+ global_step=1, return_state=False
+ )
+ assert result_without_state == {}
+
+ @pytest.mark.asyncio
+ async def test_global_logger_basic_ops(self, global_logger):
+ """Test GlobalLoggingActor basic operations."""
+ count = await global_logger.get_fetcher_count.call_one()
+ assert count >= 0
+
+ has_fetcher = await global_logger.has_fetcher.call_one("nonexistent")
+ assert has_fetcher is False
+
+ # Global logger flush (should not raise error)
+ await global_logger.flush.call_one(global_step=1)
+
+ @pytest.mark.asyncio
+ async def test_backend_init(self, local_fetcher):
+ """Test backend initialization and shutdown."""
+ metadata = {"wandb": {"shared_run_id": "test123"}}
+ config = {"console": {"logging_mode": LoggingMode.PER_RANK_REDUCE}}
+
+ await local_fetcher.init_backends.call_one(metadata, config, global_step=5)
+ await local_fetcher.shutdown.call_one()
+
+
+class TestRegistrationLifecycle:
+ """Test registration lifecycle."""
+
+ @pytest.mark.timeout(10)
+ @pytest.mark.asyncio
+ async def test_registration_lifecycle(self, global_logger, local_fetcher):
+ """Test complete registration/deregistration lifecycle."""
+ proc_name = "lifecycle_test_proc"
+
+ # Initial state
+ initial_count = await global_logger.get_fetcher_count.call_one()
+ assert await global_logger.has_fetcher.call_one(proc_name) is False
+
+ # Register
+ await global_logger.register_fetcher.call_one(local_fetcher, proc_name)
+
+ # Verify registered
+ new_count = await global_logger.get_fetcher_count.call_one()
+ assert new_count == initial_count + 1
+ assert await global_logger.has_fetcher.call_one(proc_name) is True
+
+ # Deregister
+ await global_logger.deregister_fetcher.call_one(proc_name)
+
+ # Verify deregistered
+ final_count = await global_logger.get_fetcher_count.call_one()
+ assert final_count == initial_count
+ assert await global_logger.has_fetcher.call_one(proc_name) is False
+
+
+class TestBackendConfiguration:
+ """Test backend configuration validation."""
+
+ @pytest.mark.timeout(3)
+ @pytest.mark.asyncio
+ async def test_valid_backend_configs(self, global_logger):
+ """Test valid backend configurations."""
+ # Empty config
+ await global_logger.init_backends.call_one({})
+
+ # Valid configs for different logging_mode modes
+ for logging_mode in [LoggingMode.GLOBAL_REDUCE, LoggingMode.PER_RANK_NO_REDUCE]:
+ config = {"console": {"logging_mode": logging_mode}}
+ await global_logger.init_backends.call_one(config)
+
+ def test_invalid_backend_configs(self):
+ """Test invalid backend configurations and warnings using direct validation."""
+ actor = GlobalLoggingActor()
+
+ # Test 1: Invalid logging_mode should raise ValueError
+ with pytest.raises(ValueError, match="is not a valid LoggingMode"):
+ actor._validate_backend_config("console", {"logging_mode": "invalid_mode"})
+
+ # Test 2: WandB PER_RANK_REDUCE + per_rank_share_run=True should warn
+ with patch("forge.observability.metric_actors.logger.warning") as mock_warn:
+ config = {
+ "logging_mode": "per_rank_reduce",
+ "per_rank_share_run": True,
+ "project": "test_project",
+ }
+
+ result = actor._validate_backend_config("wandb", config)
+
+ # Should have logged warning about suboptimal config
+ mock_warn.assert_called_once()
+ warning_msg = str(mock_warn.call_args)
+ assert "not recommended" in warning_msg
+
+ # Should still return valid config with LoggingMode enum
+ assert result["logging_mode"] == LoggingMode.PER_RANK_REDUCE
+ assert result["per_rank_share_run"] is True
+ assert result["project"] == "test_project"
+
+
+class TestErrorHandling:
+ """Test error handling scenarios."""
+
+ @pytest.mark.timeout(3)
+ @pytest.mark.asyncio
+ async def test_deregister_nonexistent_fetcher(self, global_logger):
+ """Test deregistering non-existent fetcher doesn't crash."""
+ await global_logger.deregister_fetcher.call_one("nonexistent_proc")
+
+ @pytest.mark.timeout(3)
+ @pytest.mark.asyncio
+ async def test_shutdown(self, global_logger):
+ """Test shutdown without issues."""
+ await global_logger.shutdown.call_one()
+
+
+class TestGetOrCreateMetricLogger:
+ """Test the integration function."""
+
+ @pytest.mark.timeout(3)
+ @pytest.mark.asyncio
+ async def test_get_or_create_functionality(self):
+ """Test get_or_create_metric_logger basic functionality."""
+ result = await get_or_create_metric_logger(process_name="TestController")
+
+ # Should return a GlobalLoggingActor mesh
+ assert result is not None
+
+ # Should be able to call basic methods
+ count = await result.get_fetcher_count.call_one()
+ assert count >= 0
diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py
new file mode 100644
index 000000000..63f7046db
--- /dev/null
+++ b/tests/unit_tests/observability/test_metrics.py
@@ -0,0 +1,493 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Unit tests for core metrics functionality."""
+
+import time
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.observability.metrics import (
+ BackendRole,
+ ConsoleBackend,
+ get_logger_backend_class,
+ LoggingMode,
+ MaxAccumulator,
+ MeanAccumulator,
+ Metric,
+ MetricCollector,
+ MinAccumulator,
+ record_metric,
+ Reduce,
+ reduce_metrics_states,
+ StdAccumulator,
+ SumAccumulator,
+ WandbBackend,
+)
+
+
+class TestMetricCreation:
+ """Test Metric object creation and record_metric function - Diff 2 features."""
+
+ def test_metric_creation_automatic_timestamp(self, mock_rank):
+ """Test Metric object creation with automatic timestamp."""
+ before_time = time.time()
+ metric = Metric("test_key", 42.0, Reduce.MEAN)
+ after_time = time.time()
+
+ assert metric.key == "test_key"
+ assert metric.value == 42.0
+ assert metric.reduction == Reduce.MEAN
+ assert metric.timestamp is not None
+ assert before_time <= metric.timestamp <= after_time
+
+ def test_metric_creation_custom_timestamp(self, mock_rank):
+ """Test Metric object creation with custom timestamp."""
+ custom_time = 1234567890.0
+ metric = Metric("test_key2", 24.0, Reduce.SUM, timestamp=custom_time)
+ assert metric.timestamp == custom_time
+
+ def test_record_metric(self, mock_rank):
+ """Test record_metric creates correct Metric and calls collector."""
+ # Mock the MetricCollector constructor to return a mock instance
+ mock_collector = MagicMock()
+
+ with patch(
+ "forge.observability.metrics.MetricCollector", return_value=mock_collector
+ ):
+ record_metric("loss", 1.5, Reduce.MEAN)
+
+ # Verify push was called on the mock collector
+ mock_collector.push.assert_called_once()
+
+ # Verify the metric passed to push
+ pushed_metric = mock_collector.push.call_args[0][0]
+ assert pushed_metric.key == "loss"
+ assert pushed_metric.value == 1.5
+ assert pushed_metric.reduction == Reduce.MEAN
+
+ def test_new_enums_and_constants(self):
+ """Test BackendRole constants and usage."""
+ # Test BackendRole enum values
+ assert BackendRole.LOCAL.value == "local"
+ assert BackendRole.GLOBAL.value == "global"
+
+ # Test that BackendRole is a proper Enum
+ assert isinstance(BackendRole.LOCAL, BackendRole)
+ assert isinstance(BackendRole.GLOBAL, BackendRole)
+
+ @pytest.mark.asyncio
+ async def test_backend_role_usage(self):
+ """Test that BackendRole constants are actually used instead of string literals."""
+ # Test ConsoleBackend
+ console_backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)
+ await console_backend.init(role=BackendRole.LOCAL)
+
+ # Test WandbBackend role validation without WandB initialization
+ wandb_backend = WandbBackend(
+ logging_mode=LoggingMode.GLOBAL_REDUCE, project="test"
+ )
+
+ # Mock all the WandB init methods to focus only on role validation
+ with (
+ patch.object(wandb_backend, "_init_global"),
+ patch.object(wandb_backend, "_init_shared_global"),
+ patch.object(wandb_backend, "_init_shared_local"),
+ patch.object(wandb_backend, "_init_per_rank"),
+ ):
+ # Should not raise error for valid roles (type system prevents invalid values)
+ await wandb_backend.init(role=BackendRole.GLOBAL)
+ await wandb_backend.init(role=BackendRole.LOCAL)
+
+
+class TestReduceOperations:
+ """Test reduce_metrics_states function returning List[Metric] - Diff 2 feature."""
+
+ def test_empty_states(self):
+ """Test reduce_metrics_states with empty input."""
+ result = reduce_metrics_states([])
+ assert result == []
+
+ def test_single_state(self):
+ """Test reduce_metrics_states with single state."""
+ states = [
+ {
+ "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+ "rollout/sample": {
+ "reduction_type": "sample",
+ "samples": [{"id": 1, "reward": 0.5}],
+ },
+ }
+ ]
+ metrics = reduce_metrics_states(states)
+ assert len(metrics) == 2
+ # Convert to dict for easier testing
+ result_dict = {m.key: (m.value, m.reduction) for m in metrics}
+
+ assert result_dict["loss"][0] == 5.0
+ assert result_dict["loss"][1] == Reduce.MEAN
+
+ assert result_dict["rollout/sample"][0] == [{"id": 1, "reward": 0.5}]
+ assert result_dict["rollout/sample"][1] == Reduce.SAMPLE
+
+ def test_multiple_states(self):
+ """Test reduce_metrics_states with multiple states."""
+ states = [
+ {
+ "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+ "rollout/sample": {
+ "reduction_type": "sample",
+ "samples": [{"id": 1, "reward": 0.5}],
+ },
+ },
+ {
+ "loss": {"reduction_type": "mean", "sum": 20.0, "count": 3},
+ "rollout/sample": {
+ "reduction_type": "sample",
+ "samples": [{"id": 2, "reward": 0.8}],
+ },
+ },
+ {"accuracy": {"reduction_type": "sum", "total": 15.0}},
+ ]
+ metrics = reduce_metrics_states(states)
+
+ assert len(metrics) == 3
+
+ # Convert to dict for easier testing
+ result_dict = {m.key: (m.value, m.reduction) for m in metrics}
+
+ # Check scalar reductions
+ assert result_dict["loss"][0] == 30.0 / 5.0 # 6.0
+ assert result_dict["loss"][1] == Reduce.MEAN
+ assert result_dict["accuracy"][0] == 15.0
+ assert result_dict["accuracy"][1] == Reduce.SUM
+
+ # Check sample concatenation
+ assert result_dict["rollout/sample"][0] == [
+ {"id": 1, "reward": 0.5},
+ {"id": 2, "reward": 0.8},
+ ]
+ assert result_dict["rollout/sample"][1] == Reduce.SAMPLE
+
+ def test_mismatched_reduction_types_raises_error(self):
+ """Test reduce_metrics_states raises error for mismatched reduction types."""
+ states = [
+ {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}},
+ {"loss": {"reduction_type": "sum", "total": 20.0}},
+ ]
+ with pytest.raises(ValueError, match="Mismatched reduction types"):
+ reduce_metrics_states(states)
+
+
+class TestAccumulators:
+ """Test all accumulator classes and their operations - Diff 2 extensions."""
+
+ def test_sum_accumulator(self):
+ """Test SumAccumulator operations."""
+ acc = SumAccumulator(Reduce.SUM)
+
+ acc.append(5.0)
+ acc.append(3.0)
+ assert acc.get_value() == 8.0
+
+ state = acc.get_state()
+ assert state["total"] == 8.0
+ assert state["reduction_type"] == "sum"
+
+ acc.reset()
+ assert acc.get_value() == 0.0
+
+ def test_max_accumulator(self):
+ """Test MaxAccumulator operations."""
+ acc = MaxAccumulator(Reduce.MAX)
+
+ acc.append(5.0)
+ acc.append(10.0)
+ acc.append(3.0)
+ assert acc.get_value() == 10.0
+
+ state = acc.get_state()
+ assert state["max_val"] == 10.0
+ assert state["reduction_type"] == "max"
+
+ def test_min_accumulator(self):
+ """Test MinAccumulator operations."""
+ acc = MinAccumulator(Reduce.MIN)
+
+ acc.append(5.0)
+ acc.append(10.0)
+ acc.append(3.0)
+ assert acc.get_value() == 3.0
+
+ state = acc.get_state()
+ assert state["min_val"] == 3.0
+ assert state["reduction_type"] == "min"
+
+ def test_std_accumulator(self):
+ """Test StdAccumulator operations."""
+ acc = StdAccumulator(Reduce.STD)
+
+ # Test with zero/one values
+ assert acc.get_value() == 0.0
+ acc.append(5.0)
+ assert acc.get_value() == 0.0 # std of single value is 0
+
+ # Test with multiple values
+ acc.append(7.0) # values: 5, 7, mean=6, std=1
+ assert abs(acc.get_value() - 1.0) < 0.001
+
+ state = acc.get_state()
+ assert state["sum"] == 12.0
+ assert state["sum_sq"] == 74.0 # 5^2 + 7^2 = 25 + 49 = 74
+ assert state["count"] == 2
+
+ @pytest.mark.parametrize(
+ "accumulator_class,states,expected",
+ [
+ (
+ MeanAccumulator,
+ [
+ {"reduction_type": "mean", "sum": 10.0, "count": 2},
+ {"reduction_type": "mean", "sum": 20.0, "count": 3},
+ ],
+ 6.0, # (10+20) / (2+3)
+ ),
+ (
+ SumAccumulator,
+ [
+ {"reduction_type": "sum", "total": 10.0},
+ {"reduction_type": "sum", "total": 15.0},
+ ],
+ 25.0,
+ ),
+ ],
+ )
+ def test_accumulator_state_reduction(self, accumulator_class, states, expected):
+ """Test cross-accumulator state reduction."""
+ result = accumulator_class.get_reduced_value_from_states(states)
+ assert result == expected
+
+ def test_reduce_enum_accumulator_mapping(self):
+ """Test that Reduce enum correctly maps to accumulator classes."""
+ assert Reduce.MEAN.accumulator_class == MeanAccumulator
+ assert Reduce.SUM.accumulator_class == SumAccumulator
+ assert Reduce.MAX.accumulator_class == MaxAccumulator
+ assert Reduce.MIN.accumulator_class == MinAccumulator
+ assert Reduce.STD.accumulator_class == StdAccumulator
+
+
+class TestCriticalFixes:
+ """Test critical production fixes from Diff 1."""
+
+ def test_uninitialized_push_logs_warning(self, mock_rank, caplog):
+ """Test MetricCollector.push() logs warning when uninitialized."""
+ collector = MetricCollector()
+ metric = Metric("test", 1.0, Reduce.MEAN)
+
+ # Should not raise error, just log warning and return
+ collector.push(metric)
+ assert any(
+ "Metric logging backends" in record.message for record in caplog.records
+ )
+
+ @pytest.mark.asyncio
+ async def test_uninitialized_flush_logs_warning(self, mock_rank, caplog):
+ """Test MetricCollector.flush() logs warning when uninitialized."""
+ collector = MetricCollector()
+
+ # Should not raise error, just log warning and return empty dict
+ result = await collector.flush(global_step=1, return_state=True)
+ assert result == {}
+ assert any(
+ "Cannot flush collected metrics" in record.message
+ for record in caplog.records
+ )
+
+ @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "true"})
+ @patch("forge.observability.metrics.MetricCollector")
+ def test_record_metric_disabled(self, mock_collector_class):
+ """Test record_metric is no-op when FORGE_DISABLE_METRICS=true."""
+ record_metric("loss", 1.5, Reduce.MEAN)
+ mock_collector_class.assert_not_called()
+
+ @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "false"})
+ @patch("forge.observability.metrics.MetricCollector")
+ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank):
+ """Test record_metric works when FORGE_DISABLE_METRICS=false."""
+ mock_collector = MagicMock()
+ mock_collector_class.return_value = mock_collector
+
+ record_metric("loss", 1.5, Reduce.MEAN)
+ mock_collector_class.assert_called_once()
+ mock_collector.push.assert_called_once()
+
+ def test_wandb_backend_creation(self):
+ """Test WandbBackend creation and basic setup without WandB dependency."""
+
+ backend = WandbBackend(
+ logging_mode=LoggingMode.GLOBAL_REDUCE,
+ project="test_project",
+ group="test_group",
+ )
+
+ # Test backend kwargs storage
+ assert backend.backend_kwargs["project"] == "test_project"
+ assert backend.backend_kwargs["group"] == "test_group"
+ assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE
+ assert backend.per_rank_share_run is False # default
+
+ # Test metadata method
+ metadata = backend.get_metadata_for_secondary_ranks()
+ assert metadata == {} # Should be empty when no run
+
+ @pytest.mark.asyncio
+ async def test_console_backend(self):
+ """Test ConsoleBackend basic operations."""
+ backend = ConsoleBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)
+
+ await backend.init(role=BackendRole.LOCAL)
+
+ # Test log_batch - should not raise
+ # Create a test metric
+ test_metric = Metric("test", 1.0, Reduce.MEAN)
+ await backend.log_batch([test_metric], global_step=1)
+
+ await backend.finish() # Should not raise
+
+
+class TestBasicAccumulators:
+ """Test basic accumulator functionality."""
+
+ def test_mean_accumulator(self):
+ """Test MeanAccumulator operations."""
+ acc = MeanAccumulator(Reduce.MEAN)
+
+ # Test initial state
+ assert acc.get_value() == 0.0
+ state = acc.get_state()
+ assert state["sum"] == 0.0
+ assert state["count"] == 0
+
+ # Test append and get_value
+ acc.append(10.0)
+ acc.append(20.0)
+ assert acc.get_value() == 15.0
+
+ # Test state
+ state = acc.get_state()
+ assert state["sum"] == 30.0
+ assert state["count"] == 2
+ assert state["reduction_type"] == "mean"
+
+ # Test reset
+ acc.reset()
+ assert acc.get_value() == 0.0
+ assert acc.get_state()["sum"] == 0.0
+ assert acc.get_state()["count"] == 0
+
+ def test_reduce_enum_accumulator_mapping(self):
+ """Test that Reduce enum correctly maps to accumulator classes."""
+ assert Reduce.MEAN.accumulator_class == MeanAccumulator
+
+
+class TestBackendFactory:
+ """Test backend factory function."""
+
+ def test_backend_factory(self):
+ """Test get_logger_backend_class factory function."""
+ assert get_logger_backend_class("console") == ConsoleBackend
+ assert get_logger_backend_class("wandb") == WandbBackend
+
+ with pytest.raises(ValueError, match="Unknown logger backend type"):
+ get_logger_backend_class("invalid_backend")
+
+
+class TestMetricCollector:
+ """Test MetricCollector singleton behavior."""
+
+ def test_singleton_per_rank(self, mock_rank):
+ """Test MetricCollector singleton behavior per rank."""
+ mock_rank.return_value.rank = 0
+ collector1 = MetricCollector()
+ collector2 = MetricCollector()
+ assert collector1 is collector2
+
+ # Different rank should get different instance
+ mock_rank.return_value.rank = 1
+ collector3 = MetricCollector()
+ assert collector1 is not collector3
+
+
+class TestMetricActorDisabling:
+ """Test environment flag to disable metric actors."""
+
+ async def _test_fetcher_registration(self, env_var_value, should_register_fetchers):
+ """Check if FORGE_DISABLE_METRICS=[True, False, None] correctly disables fetcher registration.
+
+ Args:
+ env_var_value: Value to set for FORGE_DISABLE_METRICS (None means unset)
+ should_register_fetchers: Whether fetchers should be registered (True) or not (False)
+ """
+ import os
+
+ import forge.observability.metric_actors
+ from forge.env import FORGE_DISABLE_METRICS
+ from monarch.actor import this_host
+
+ # set fresh env
+ # Note: Environment variable setup is handled by clean_metrics_environment fixture
+ forge.observability.metric_actors._global_logger = None
+
+ if env_var_value is not None:
+ os.environ[FORGE_DISABLE_METRICS.name] = env_var_value
+
+ procs = this_host().spawn_procs(per_host={"cpus": 1})
+
+ if hasattr(procs, "_local_fetcher"):
+ delattr(procs, "_local_fetcher")
+
+ # Test functionality - pass explicit process_name since test bypasses provisioner
+ global_logger = await get_or_create_metric_logger(
+ proc_mesh=procs, process_name="TestProcess"
+ )
+
+ # Get results to check
+ proc_has_fetcher = hasattr(procs, "_local_fetcher")
+ proc_id = procs._uid if hasattr(procs, "_uid") else None
+ global_has_fetcher = (
+ await global_logger.has_fetcher.call_one(proc_id) if proc_id else False
+ )
+
+ # Assert based on expected behavior
+ if should_register_fetchers:
+ assert (
+ proc_has_fetcher
+ ), f"Expected process to have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
+ assert (
+ global_has_fetcher
+ ), f"Expected global logger to have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
+ else:
+ assert (
+ not proc_has_fetcher
+ ), f"Expected process to NOT have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
+ assert (
+ not global_has_fetcher
+ ), f"Expected global logger to NOT have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "env_value,should_register",
+ [
+ ("false", True),
+ ("true", False),
+ (None, True),
+ ],
+ )
+ async def test_fetcher_registration_with_env_flag(self, env_value, should_register):
+ """Test fetcher registration behavior with different environment flag values."""
+ await self._test_fetcher_registration(env_value, should_register)
diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py
index 6af7331f1..dd75d4540 100644
--- a/tests/unit_tests/observability/test_perf_tracker.py
+++ b/tests/unit_tests/observability/test_perf_tracker.py
@@ -7,12 +7,12 @@
import asyncio
import time
from contextlib import contextmanager
-from typing import List, Literal, Tuple, Union
+from typing import Literal, Union
from unittest.mock import Mock, patch
import pytest
import torch
-from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA
+from forge.env import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU
from forge.observability.metrics import Reduce
from forge.observability.perf_tracker import _TimerCPU, _TimerCUDA, trace, Tracer
@@ -21,7 +21,7 @@
@pytest.fixture
def mock_record_metric_calls(monkeypatch):
"""Mock record_metric that tracks all calls."""
- calls: List[Tuple[str, float, Reduce]] = []
+ calls: list[tuple[str, float, Reduce]] = []
def mock_record_metric(name, val, red):
calls.append((name, val, red))
@@ -55,7 +55,8 @@ def assert_metrics_dict_matches(calls, expected_metrics):
assert metric_name in actual_metrics, f"Missing metric: {metric_name}"
actual_val = actual_metrics[metric_name]
assert actual_val == pytest.approx(
- expected_val, rel=0.1 # 10% relative tolerance for timing tests
+ expected_val,
+ rel=0.2, # 20% relative tolerance for timing tests
), f"Expected {metric_name}={expected_val}, got {actual_val}"
@@ -135,7 +136,7 @@ def test_comprehensive_workflow(
if timer == "gpu" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
- monkeypatch.setenv(METRIC_TIMER_USES_CUDA, str(timer == "gpu"))
+ monkeypatch.setenv(METRIC_TIMER_USES_GPU.name, str(timer == "gpu"))
async def run_concurrent_tasks():
start_time = time.perf_counter()
@@ -162,6 +163,7 @@ async def run_concurrent_tasks():
if timer == "gpu" and torch.cuda.is_available():
assert isinstance(tracer._timer, _TimerCUDA), "Expected CUDA timer"
else:
+ value = METRIC_TIMER_USES_GPU.get_value()
assert isinstance(tracer._timer, _TimerCPU), "Expected CPU timer"
tracer.step("backend_check")
tracer.stop()
@@ -309,29 +311,39 @@ def test_tracer_and_timer_reuse(self, mock_record_metric_calls):
cpu_timer.start()
time.sleep(0.005)
cpu_timer.step("cpu_step1")
- durations1 = cpu_timer.get_all_durations()
+ cpu_durations_list1, cpu_final_ms1 = cpu_timer.get_all_durations()
cpu_timer.start()
time.sleep(0.005)
cpu_timer.step("cpu_step2")
- durations2 = cpu_timer.get_all_durations()
+ cpu_durations_list2, cpu_final_ms2 = cpu_timer.get_all_durations()
- assert len(durations1) == 1 and durations1[0][0] == "cpu_step1"
- assert len(durations2) == 1 and durations2[0][0] == "cpu_step2"
+ assert (
+ len(cpu_durations_list1) == 1 and cpu_durations_list1[0][0] == "cpu_step1"
+ )
+ assert (
+ len(cpu_durations_list2) == 1 and cpu_durations_list2[0][0] == "cpu_step2"
+ )
# Test CUDA timer reuse (if available)
if torch.cuda.is_available():
cuda_timer = _TimerCUDA()
cuda_timer.start()
cuda_timer.step("cuda_step1")
- cuda_durations1 = cuda_timer.get_all_durations()
+ cuda_durations_list1, cuda_final_ms1 = cuda_timer.get_all_durations()
cuda_timer.start()
cuda_timer.step("cuda_step2")
- cuda_durations2 = cuda_timer.get_all_durations()
+ cuda_durations_list2, cuda_final_ms2 = cuda_timer.get_all_durations()
- assert len(cuda_durations1) == 1 and cuda_durations1[0][0] == "cuda_step1"
- assert len(cuda_durations2) == 1 and cuda_durations2[0][0] == "cuda_step2"
+ assert (
+ len(cuda_durations_list1) == 1
+ and cuda_durations_list1[0][0] == "cuda_step1"
+ )
+ assert (
+ len(cuda_durations_list2) == 1
+ and cuda_durations_list2[0][0] == "cuda_step2"
+ )
def test_exception_handling_context_manager(self, mock_record_metric_calls):
"""Test context manager properly cleans up on exception."""
@@ -354,7 +366,7 @@ def test_disable_perf_metrics_all_modes(
self, mode, monkeypatch, mock_record_metric_calls
):
"""Test DISABLE_PERF_METRICS disables all modes."""
- monkeypatch.setenv(DISABLE_PERF_METRICS, "true")
+ monkeypatch.setenv(DISABLE_PERF_METRICS.name, "true")
async def disabled_workflow():
return await TracingModes.run_workflow(mode, f"disabled_{mode}")
@@ -370,17 +382,18 @@ async def disabled_workflow():
("false", _TimerCPU),
],
)
- def test_metric_timer_uses_cuda_override(
+ def test_metric_timer_uses_gpu_override(
self, env_value, expected_backend, monkeypatch
):
- """Test METRIC_TIMER_USES_CUDA env var overrides timer parameter."""
+ """Test METRIC_TIMER_USES_GPU env var overrides timer parameter."""
if env_value == "true" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
- with patch("torch.cuda.is_available", return_value=True), patch(
- "forge.observability.perf_tracker.record_metric"
+ with (
+ patch("torch.cuda.is_available", return_value=True),
+ patch("forge.observability.perf_tracker.record_metric"),
):
- monkeypatch.setenv(METRIC_TIMER_USES_CUDA, env_value)
+ monkeypatch.setenv(METRIC_TIMER_USES_GPU.name, env_value)
# Test with timer="cpu" (should be overridden by env)
tracer = Tracer("env_test", timer="cpu")
diff --git a/tests/unit_tests/observability/test_utils.py b/tests/unit_tests/observability/test_utils.py
new file mode 100644
index 000000000..6b173cc42
--- /dev/null
+++ b/tests/unit_tests/observability/test_utils.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Tests for observability utility functions."""
+
+from forge.controller.actor import ForgeActor
+
+from forge.observability.utils import get_proc_name_with_rank
+from monarch.actor import endpoint
+
+
+class UtilActor(ForgeActor):
+ """Actor for testing get_proc_name_with_rank in spawned context."""
+
+ @endpoint
+ async def get_name(self) -> str:
+ return get_proc_name_with_rank()
+
+
+class TestGetProcNameWithRank:
+ """Tests for get_proc_name_with_rank utility."""
+
+ def test_direct_proc(self):
+ """Direct proc should return 'client_r0'."""
+ assert get_proc_name_with_rank() == "client_r0"
+
+ def test_direct_proc_with_override(self):
+ """Direct proc with override should use provided name."""
+ result = get_proc_name_with_rank(proc_name="MyProcess")
+ assert result == "MyProcess_r0"
+
+ # TODO (felipemello): currently not working with CI wheel, but passes locally
+ # reactive once wheel is updated with new monarch version
+ # @pytest.mark.timeout(10)
+ # @pytest.mark.asyncio
+ # async def test_replicas(self):
+ # """Test service with replicas returns unique names and hashes per replica."""
+ # actor = await UtilActor.options(
+ # procs=1, num_replicas=2, with_gpus=False
+ # ).as_service()
+ # results = await actor.get_name.fanout()
+
+ # assert len(results) == 2
+ # assert len(set(results)) == 2 # All names are unique
+ # for name in results:
+ # assert name.startswith("UtilActor")
+ # assert name.endswith("_r0")
+
+ # # Extract hashes from names (format: ActorName_replicaIdx_hash_r0)
+ # hashes = [name.split("_")[-2] for name in results]
+ # assert hashes[0] != hashes[1] # Hashes are different between replicas
diff --git a/tests/unit_tests/rl/environments/__init__.py b/tests/unit_tests/rl/environments/__init__.py
deleted file mode 100644
index 2e41cd717..000000000
--- a/tests/unit_tests/rl/environments/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
diff --git a/tests/unit_tests/rl/environments/test_chat.py b/tests/unit_tests/rl/environments/test_chat.py
deleted file mode 100644
index 678e0a33f..000000000
--- a/tests/unit_tests/rl/environments/test_chat.py
+++ /dev/null
@@ -1,331 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import unittest
-from typing import Any, Dict, List, Optional
-from unittest.mock import MagicMock
-
-import torch
-
-from forge.envs.chat import (
- ChatAction,
- ChatEnvironment,
- ChatObservation,
- ChatState,
- Message,
-)
-
-
-class MockTokenizer:
- """Mock tokenizer implementing TokenizerProtocol for testing."""
-
- def apply_chat_template(
- self,
- conversation: List[Dict[str, str]],
- tools: Optional[List[Dict]] = None,
- documents: Optional[List[Dict[str, str]]] = None,
- chat_template: Optional[str] = None,
- add_generation_prompt: bool = False,
- continue_final_message: bool = False,
- tokenize: bool = True,
- padding: bool = False,
- truncation: bool = False,
- max_length: Optional[int] = None,
- return_tensors: Optional[str] = None,
- return_dict: bool = False,
- return_assistant_tokens_mask: bool = False,
- tokenizer_kwargs: Optional[Dict[str, Any]] = None,
- **kwargs,
- ) -> torch.Tensor:
- """Mock implementation of apply_chat_template."""
- # For testing, we'll just return a tensor with a simple pattern based on the conversation
- # Each message contributes 10 tokens to the output
- return torch.tensor([[i for i in range(len(conversation) * 10)]])
-
- def decode(
- self,
- token_ids: Any,
- skip_special_tokens: bool = False,
- clean_up_tokenization_spaces: Optional[bool] = None,
- **kwargs,
- ) -> str:
- """Mock implementation of decode."""
- # For testing, we'll just convert the tensor to a string
- if isinstance(token_ids, torch.Tensor):
- return f"Decoded: {token_ids.tolist()}"
- return f"Decoded: {token_ids}"
-
-
-class TestChatAction(unittest.TestCase):
- """Test the ChatAction class."""
-
- def test_init(self):
- """Test initialization of ChatAction."""
- tokens = torch.tensor([1, 2, 3])
- action = ChatAction(tokens=tokens)
- self.assertTrue(torch.equal(action.tokens, tokens))
-
- def test_init_empty_tokens(self):
- """Test initialization with empty tokens raises ValueError."""
- with self.assertRaises(ValueError):
- ChatAction(tokens=torch.tensor([]))
-
-
-class TestChatState(unittest.TestCase):
- """Test the ChatState class."""
-
- def test_init(self):
- """Test initialization of ChatState."""
- state = ChatState()
- self.assertEqual(state.history_messages, [])
- self.assertEqual(state.history_tokens, [])
-
- def test_init_with_values(self):
- """Test initialization with provided values."""
- messages: List[Message] = [{"role": "user", "content": "Hello"}]
- tokens = [torch.tensor([1, 2, 3])]
- state = ChatState(history_messages=messages, history_tokens=tokens)
- self.assertEqual(state.history_messages, messages)
- self.assertEqual(state.history_tokens, tokens)
-
-
-class TestChatObservation(unittest.TestCase):
- """Test the ChatObservation class."""
-
- def test_init(self):
- """Test initialization of ChatObservation."""
- obs = ChatObservation()
- self.assertEqual(obs.messages, [])
- self.assertEqual(obs.tokens.numel(), 0)
- self.assertFalse(obs.done)
- self.assertIsNone(obs.reward)
- self.assertEqual(obs.metadata, {})
-
- def test_init_with_values(self):
- """Test initialization with provided values."""
- messages: List[Message] = [{"role": "user", "content": "Hello"}]
- tokens = torch.tensor([1, 2, 3])
- obs = ChatObservation(
- messages=messages,
- tokens=tokens,
- done=True,
- reward=1.0,
- metadata={"test": "value"},
- )
- self.assertEqual(obs.messages, messages)
- self.assertTrue(torch.equal(obs.tokens, tokens))
- self.assertTrue(obs.done)
- self.assertEqual(obs.reward, 1.0)
- self.assertEqual(obs.metadata, {"test": "value"})
-
-
-class TestChatEnvironment(unittest.TestCase):
- """Test the ChatEnvironment class."""
-
- def setUp(self):
- """Set up test fixtures."""
- self.tokenizer = MockTokenizer()
-
- def test_init_no_system_prompt(self):
- """Test initialization without system prompt."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- self.assertEqual(env._state.history_messages, [])
- self.assertEqual(env._state.history_tokens, [])
-
- def test_init_with_system_prompt(self):
- """Test initialization with system prompt."""
- env = ChatEnvironment(
- tokenizer=self.tokenizer,
- system_prompt="You are a helpful assistant",
- system_role="system",
- )
- self.assertEqual(len(env._state.history_messages), 1)
- self.assertEqual(env._state.history_messages[0]["role"], "system")
- self.assertEqual(
- env._state.history_messages[0]["content"], "You are a helpful assistant"
- )
- self.assertEqual(len(env._state.history_tokens), 1)
-
- def test_init_invalid_tokenizer(self):
- """Test initialization with invalid tokenizer."""
- # Create a mock with no attributes by setting spec=[]
- invalid_tokenizer = MagicMock(spec=[])
- with self.assertRaises(ValueError):
- ChatEnvironment(tokenizer=invalid_tokenizer)
-
- def test_reset_no_system_prompt(self):
- """Test reset without system prompt."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- # Add some history first
- env._state.history_messages = [{"role": "user", "content": "Hello"}] # type: ignore
- env._state.history_tokens = [torch.tensor([1, 2, 3])]
-
- # Reset should clear the history
- obs = env.reset()
- self.assertEqual(env._state.history_messages, [])
- self.assertEqual(env._state.history_tokens, [])
- self.assertEqual(obs.messages, [])
- self.assertEqual(obs.tokens.numel(), 0)
-
- def test_reset_with_system_prompt(self):
- """Test reset with system prompt."""
- env = ChatEnvironment(
- tokenizer=self.tokenizer,
- system_prompt="You are a helpful assistant",
- system_role="system",
- )
- # Add some history first
- env._state.history_messages = [
- {"role": "system", "content": "You are a helpful assistant"},
- {"role": "user", "content": "Hello"},
- ] # type: ignore
- env._state.history_tokens = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
-
- # Reset should clear the history and add the system prompt
- obs = env.reset()
- self.assertEqual(len(env._state.history_messages), 1)
- self.assertEqual(env._state.history_messages[0]["role"], "system")
- self.assertEqual(
- env._state.history_messages[0]["content"], "You are a helpful assistant"
- )
- self.assertEqual(len(env._state.history_tokens), 1)
- self.assertEqual(len(obs.messages), 1)
- self.assertEqual(obs.messages[0]["role"], "system")
- self.assertEqual(obs.messages[0]["content"], "You are a helpful assistant")
-
- def test_step(self):
- """Test step method."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- action = ChatAction(tokens=torch.tensor([1, 2, 3]))
-
- obs = env.step(action)
-
- # Check that the tokens were added to history
- self.assertEqual(len(env._state.history_tokens), 1)
- self.assertTrue(
- torch.equal(env._state.history_tokens[0], torch.tensor([1, 2, 3]))
- )
-
- # Check that the message was added to history with decoded content
- self.assertEqual(len(env._state.history_messages), 1)
- self.assertEqual(env._state.history_messages[0]["role"], "assistant")
- self.assertEqual(
- env._state.history_messages[0]["content"], "Decoded: [1, 2, 3]"
- )
-
- # Check the observation
- self.assertEqual(len(obs.messages), 1)
- self.assertEqual(obs.messages[0]["role"], "assistant")
- self.assertEqual(obs.messages[0]["content"], "Decoded: [1, 2, 3]")
-
- def test_create_observation(self):
- """Test _create_observation method."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- env._state.history_messages = [
- {"role": "system", "content": "You are a helpful assistant"},
- {"role": "user", "content": "Hello"},
- ] # type: ignore
- env._state.history_tokens = [
- torch.tensor([[1, 2, 3]]),
- torch.tensor([[4, 5, 6]]),
- ]
-
- obs = env._create_observation()
-
- # Check the observation
- self.assertEqual(len(obs.messages), 2)
- self.assertEqual(obs.messages[0]["role"], "system")
- self.assertEqual(obs.messages[0]["content"], "You are a helpful assistant")
- self.assertEqual(obs.messages[1]["role"], "user")
- self.assertEqual(obs.messages[1]["content"], "Hello")
-
- # Check that the tokens were concatenated
- self.assertEqual(obs.tokens.numel(), 6) # 2 tensors of size 3
-
- def test_create_observation_empty_history(self):
- """Test _create_observation method with empty history."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
-
- obs = env._create_observation()
-
- # Check the observation
- self.assertEqual(obs.messages, [])
- self.assertEqual(obs.tokens.numel(), 0)
-
- def test_state_property(self):
- """Test state property."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- state = env.state
- self.assertIsInstance(state, ChatState)
- self.assertEqual(state.history_messages, [])
- self.assertEqual(state.history_tokens, [])
-
- def test_message_to_action(self):
- """Test message_to_action method."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- message: Message = {"role": "user", "content": "Hello"}
-
- action = env.message_to_action(message)
-
- self.assertIsInstance(action, ChatAction)
- self.assertEqual(
- action.tokens.numel(), 10
- ) # Mock tokenizer returns 10 tokens per message
-
- def test_message_to_action_missing_role(self):
- """Test message_to_action method with missing role."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- # We're intentionally creating an invalid message to test error handling
- message = {"content": "Hello"} # type: ignore
-
- with self.assertRaises(ValueError):
- # Using type: ignore because we're intentionally passing an invalid message
- env.message_to_action(message) # type: ignore
-
- def test_message_to_action_missing_content(self):
- """Test message_to_action method with missing content."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- # We're intentionally creating an invalid message to test error handling
- message = {"role": "user"} # type: ignore
-
- with self.assertRaises(ValueError):
- # Using type: ignore because we're intentionally passing an invalid message
- env.message_to_action(message) # type: ignore
-
- def test_message_to_action_none_content(self):
- """Test message_to_action method with None content."""
- env = ChatEnvironment(tokenizer=self.tokenizer)
- # We're intentionally creating an invalid message to test error handling
- message = {"role": "user", "content": None} # type: ignore
-
- with self.assertRaises(ValueError):
- # Using type: ignore because we're intentionally passing an invalid message
- env.message_to_action(message) # type: ignore
-
- def test_with_transform(self):
- """Test environment with a transform."""
-
- def transform(obs):
- obs.metadata["transformed"] = True
- obs.reward = 1.0
- return obs
-
- env = ChatEnvironment(tokenizer=self.tokenizer, transform=transform)
- action = ChatAction(tokens=torch.tensor([1, 2, 3]))
-
- obs = env.step(action)
-
- self.assertTrue(obs.metadata.get("transformed"))
- self.assertEqual(obs.reward, 1.0)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/unit_tests/test_coder.py b/tests/unit_tests/test_coder.py
new file mode 100644
index 000000000..6875114b7
--- /dev/null
+++ b/tests/unit_tests/test_coder.py
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Unit tests for forge.actors.coder.SandboxedPythonCoder.
+"""
+
+import os
+import tempfile
+import uuid
+from unittest.mock import Mock, patch
+
+import pytest
+
+from forge.actors.coder import _SandboxedPythonCoder
+
+
+@pytest.mark.asyncio
+async def test_coder_success():
+ """Test successful execution."""
+ unique_id = str(uuid.uuid4())[:8]
+ container_name = f"test_sandbox_{unique_id}"
+
+ with tempfile.NamedTemporaryFile(suffix=".sqsh", delete=False) as temp_image:
+ image_path = temp_image.name
+
+ def mock_subprocess_run(*args, **kwargs):
+ """Mock subprocess.run for testing."""
+ cmd = args[0] if args else kwargs.get("args", [])
+ cmd_str = " ".join(cmd) if isinstance(cmd, list) else str(cmd)
+
+ if "import" in cmd_str:
+ result = Mock()
+ result.returncode = 0
+ result.stderr = ""
+ return result
+ elif "remove" in cmd_str:
+ result = Mock()
+ result.returncode = 0
+ return result
+ elif "create" in cmd_str:
+ result = Mock()
+ result.returncode = 0
+ result.stderr = ""
+ return result
+ elif "start" in cmd_str:
+ result = Mock()
+ result.returncode = 0
+ result.stdout = "Hello World\n"
+ result.stderr = ""
+ return result
+ else:
+ raise ValueError(f"Unexpected subprocess call: {cmd_str}")
+
+ try:
+ with patch(
+ "forge.actors.coder.subprocess.run", side_effect=mock_subprocess_run
+ ):
+ coder = _SandboxedPythonCoder(
+ docker_image="docker://python:3.10",
+ sqsh_image_path=image_path,
+ container_name=container_name,
+ )
+
+ await coder.setup()
+ result, _ = await coder.execute(code="print('Hello World')")
+ assert result == "Hello World\n"
+ finally:
+ if os.path.exists(image_path):
+ os.unlink(image_path)
+
+
+@pytest.mark.asyncio
+async def test_coder_execution_failure():
+ """Test execution failure."""
+ unique_id = str(uuid.uuid4())[:8]
+ container_name = f"test_sandbox_{unique_id}"
+
+ with tempfile.NamedTemporaryFile(suffix=".sqsh", delete=False) as temp_image:
+ image_path = temp_image.name
+
+ def mock_subprocess_run(*args, **kwargs):
+ """Mock subprocess.run for testing."""
+ cmd = args[0] if args else kwargs.get("args", [])
+ cmd_str = " ".join(cmd) if isinstance(cmd, list) else str(cmd)
+
+ if "import" in cmd_str:
+ result = Mock()
+ result.returncode = 0
+ result.stderr = ""
+ return result
+ elif "remove" in cmd_str:
+ result = Mock()
+ result.returncode = 0
+ return result
+ elif "create" in cmd_str:
+ result = Mock()
+ result.returncode = 0
+ result.stderr = ""
+ return result
+ elif "start" in cmd_str:
+ result = Mock()
+ result.returncode = 1
+ result.stdout = ""
+ result.stderr = "SyntaxError: invalid syntax"
+ return result
+ else:
+ raise ValueError(f"Unexpected subprocess call: {cmd_str}")
+
+ try:
+ with patch(
+ "forge.actors.coder.subprocess.run", side_effect=mock_subprocess_run
+ ):
+ coder = _SandboxedPythonCoder(
+ docker_image="docker://python:3.10",
+ sqsh_image_path=image_path,
+ container_name=container_name,
+ )
+
+ await coder.setup()
+ output, err = await coder.execute(code="invalid syntax")
+ assert "SyntaxError" in err
+ finally:
+ if os.path.exists(image_path):
+ os.unlink(image_path)
diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py
index 69cc7e2ed..64a00c759 100644
--- a/tests/unit_tests/test_config.py
+++ b/tests/unit_tests/test_config.py
@@ -8,7 +8,7 @@
import pytest
-from forge.cli.config import resolve_hf_hub_paths
+from forge.util.config import resolve_hf_hub_paths
from omegaconf import DictConfig, OmegaConf
@@ -39,7 +39,7 @@
({"level1": {"level2": {"model": "hf://deep/model"}}}, [("deep/model",)]),
],
)
-@patch("forge.cli.config.snapshot_download")
+@patch("forge.util.config.snapshot_download")
def test_hf_path_resolution(mock_download, config_data, expected_calls):
"""Test hf:// path resolution in various data structures."""
mock_download.return_value = "/fake/cache/model"
@@ -78,7 +78,7 @@ def test_non_hf_paths_unchanged(config_data):
# Cache behavior tests
-@patch("forge.cli.config.snapshot_download")
+@patch("forge.util.config.snapshot_download")
def test_cache_hit_scenario(mock_download):
"""Test behavior when model is already cached."""
mock_download.return_value = "/fake/cache/model"
@@ -93,7 +93,7 @@ def test_cache_hit_scenario(mock_download):
assert result.model == "/fake/cache/model"
-@patch("forge.cli.config.snapshot_download")
+@patch("forge.util.config.snapshot_download")
def test_cache_miss_scenario(mock_download):
"""Test behavior when model is not cached."""
from huggingface_hub.utils import LocalEntryNotFoundError
@@ -145,7 +145,7 @@ def test_invalid_hf_urls(invalid_hf_url, expected_error):
assert expected_error in str(exc_info.value)
-@patch("forge.cli.config.snapshot_download")
+@patch("forge.util.config.snapshot_download")
def test_download_failure_handling(mock_download):
"""Test error handling when download fails."""
mock_download.side_effect = Exception("Network error: Repository not found")
@@ -159,7 +159,7 @@ def test_download_failure_handling(mock_download):
# Integration test with mixed data types
-@patch("forge.cli.config.snapshot_download")
+@patch("forge.util.config.snapshot_download")
def test_complex_real_world_config(mock_download):
"""Test with a realistic complex configuration."""
mock_download.return_value = "/fake/cache/model"
diff --git a/tests/unit_tests/test_env_constants.py b/tests/unit_tests/test_env_constants.py
new file mode 100644
index 000000000..68d7806bb
--- /dev/null
+++ b/tests/unit_tests/test_env_constants.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Unit tests for env_constants module."""
+
+import os
+
+from forge.env import all_env_vars, DISABLE_PERF_METRICS, EnvVar, FORGE_DISABLE_METRICS
+
+
+class TestEnvVarGetValue:
+ """Test the EnvVar.get_value() method."""
+
+ def test_get_value_returns_default_when_unset(self):
+ """Test get_value returns default when env var is not set."""
+ if "DISABLE_PERF_METRICS" in os.environ:
+ del os.environ["DISABLE_PERF_METRICS"]
+
+ value = DISABLE_PERF_METRICS.get_value()
+ assert value is False
+
+ def test_get_value_returns_env_value_when_set(self):
+ """Test get_value returns env var value when set."""
+ from forge.env import MONARCH_STDERR_LEVEL
+
+ os.environ["MONARCH_STDERR_LOG"] = "debug"
+
+ try:
+ value = MONARCH_STDERR_LEVEL.get_value()
+ assert value == "debug"
+ finally:
+ del os.environ["MONARCH_STDERR_LOG"]
+
+ def test_get_value_bool_auto_cast_with_true(self):
+ """Test get_value auto-casts 'true' to bool."""
+ os.environ["DISABLE_PERF_METRICS"] = "true"
+ try:
+ assert DISABLE_PERF_METRICS.get_value() is True
+ finally:
+ del os.environ["DISABLE_PERF_METRICS"]
+
+ def test_get_value_bool_auto_cast_with_one(self):
+ """Test get_value auto-casts '1' to bool."""
+ os.environ["DISABLE_PERF_METRICS"] = "1"
+ try:
+ assert DISABLE_PERF_METRICS.get_value() is True
+ finally:
+ del os.environ["DISABLE_PERF_METRICS"]
+
+ def test_get_value_bool_auto_cast_with_false(self):
+ """Test get_value auto-casts 'false' to bool."""
+ os.environ["DISABLE_PERF_METRICS"] = "false"
+ try:
+ assert DISABLE_PERF_METRICS.get_value() is False
+ finally:
+ del os.environ["DISABLE_PERF_METRICS"]
+
+
+class TestPredefinedConstants:
+ """Test the predefined environment variable constants."""
+
+ def test_predefined_constants_structure(self):
+ """Test that predefined constants are properly defined."""
+ assert isinstance(DISABLE_PERF_METRICS, EnvVar)
+ assert DISABLE_PERF_METRICS.name == "DISABLE_PERF_METRICS"
+ assert DISABLE_PERF_METRICS.default is False
+
+ assert isinstance(FORGE_DISABLE_METRICS, EnvVar)
+ assert FORGE_DISABLE_METRICS.name == "FORGE_DISABLE_METRICS"
+ assert FORGE_DISABLE_METRICS.default is False
+
+ def test_predefined_constants_work_with_get_value(self):
+ """Test that predefined constants work with get_value method."""
+ if DISABLE_PERF_METRICS.name in os.environ:
+ del os.environ[DISABLE_PERF_METRICS.name]
+
+ assert DISABLE_PERF_METRICS.get_value() is False
+
+ os.environ[DISABLE_PERF_METRICS.name] = "true"
+ try:
+ assert DISABLE_PERF_METRICS.get_value() is True
+ finally:
+ del os.environ[DISABLE_PERF_METRICS.name]
+
+
+class TestAllEnvVars:
+ """Test the all_env_vars() function."""
+
+ def test_all_env_vars_returns_list(self):
+ """Test that all_env_vars returns a list."""
+ env_vars = all_env_vars()
+ assert isinstance(env_vars, list)
+
+ def test_all_env_vars_contains_only_env_var_instances(self):
+ """Test that all_env_vars returns only EnvVar instances."""
+ env_vars = all_env_vars()
+ assert len(env_vars) > 0
+ assert all(isinstance(var, EnvVar) for var in env_vars)
+
+ def test_all_env_vars_contains_expected_constants(self):
+ """Test that all_env_vars includes known constants."""
+ env_vars = all_env_vars()
+ env_var_names = {var.name for var in env_vars}
+
+ assert "DISABLE_PERF_METRICS" in env_var_names
+ assert "FORGE_DISABLE_METRICS" in env_var_names
+ assert "MONARCH_STDERR_LOG" in env_var_names
+
+ def test_all_env_vars_can_iterate_and_get_values(self):
+ """Test that all_env_vars can be used to iterate and get values."""
+ for env_var in all_env_vars():
+ value = env_var.get_value()
+ assert value is not None or env_var.default is None
diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py
new file mode 100644
index 000000000..94cb58859
--- /dev/null
+++ b/tests/unit_tests/test_generator_config.py
@@ -0,0 +1,137 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+
+import pytest
+import yaml
+
+
+def _import_error():
+ """Check if there are import errors that would cause CI failures."""
+ try:
+ import forge.actors.generator # noqa: F401
+
+ return False
+ except ImportError:
+ return True
+
+
+class TestGeneratorConfig(unittest.TestCase):
+ """Test suite for Generator configuration handling after PolicyConfig removal."""
+
+ @pytest.mark.skipif(
+ _import_error(),
+ reason="Import error, likely due to missing dependencies on CI.",
+ )
+ def test_generator_default_initialization(self):
+ """Generator initializes with default values."""
+ from forge.actors.generator import Generator
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.sampling_params import SamplingParams
+
+ generator = Generator()
+
+ # Default factories
+ self.assertIsInstance(generator.engine_args, EngineArgs)
+ self.assertIsInstance(generator.sampling_params, SamplingParams)
+
+ # Worker defaults
+ self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
+ self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
+ self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
+ self.assertFalse(generator.engine_args.enforce_eager)
+ self.assertTrue(generator.engine_args._is_v1_supported_oracle())
+
+ # Sampling defaults
+ self.assertEqual(generator.sampling_params.n, 1)
+ self.assertFalse(generator.sampling_params.guided_decoding)
+ self.assertEqual(generator.sampling_params.max_tokens, 16)
+
+ @pytest.mark.skipif(
+ _import_error(),
+ reason="Import error, likely due to missing dependencies on CI.",
+ )
+ def test_generator_with_dict_configs(self):
+ from forge.actors.generator import Generator
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.sampling_params import SamplingParams
+
+ engine_dict = {
+ "model": "Qwen/Qwen3-0.6B",
+ "tensor_parallel_size": 1,
+ "pipeline_parallel_size": 1,
+ "enforce_eager": True,
+ "gpu_memory_utilization": 0.1,
+ "max_model_len": 1024,
+ }
+
+ sampling_dict = {
+ "n": 2,
+ "max_tokens": 32,
+ }
+
+ generator = Generator(
+ engine_args=engine_dict,
+ sampling_params=sampling_dict,
+ )
+
+ self.assertIsInstance(generator.engine_args, EngineArgs)
+ self.assertIsInstance(generator.sampling_params, SamplingParams)
+
+ # Test basic fields
+ self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
+ self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
+ self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
+ self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.1)
+ self.assertEqual(generator.engine_args.max_model_len, 1024)
+ self.assertTrue(generator.engine_args.enforce_eager)
+ self.assertTrue(generator.engine_args._is_v1_supported_oracle())
+
+ self.assertEqual(generator.sampling_params.n, 2)
+ self.assertEqual(generator.sampling_params.max_tokens, 32)
+
+ @pytest.mark.skipif(
+ _import_error(),
+ reason="Import error, likely due to missing dependencies on CI.",
+ )
+ def test_generator_yaml_config_loading(self):
+ """Generator can be constructed from a YAML config file."""
+ from forge.actors.generator import Generator
+
+ yaml_content = """
+ engine_args:
+ model: "Qwen/Qwen3-0.6B"
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+ enforce_eager: true
+
+ sampling_params:
+ n: 2
+ max_tokens: 32
+ """
+
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+ f.write(yaml_content)
+ f.flush()
+
+ with open(f.name, "r") as yaml_file:
+ config = yaml.safe_load(yaml_file)
+
+ generator = Generator(**config)
+ self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
+ self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
+ self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
+ self.assertTrue(generator.engine_args.enforce_eager)
+ self.assertTrue(generator.engine_args._is_v1_supported_oracle())
+
+ self.assertEqual(generator.sampling_params.n, 2)
+ self.assertEqual(generator.sampling_params.max_tokens, 32)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit_tests/test_gpu_manager.py b/tests/unit_tests/test_gpu_manager.py
deleted file mode 100644
index cb99e903e..000000000
--- a/tests/unit_tests/test_gpu_manager.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""Tests for GPU manager functionality."""
-
-import pytest
-from forge.controller.system_controllers.gpu_manager import GpuManager
-from monarch.actor import ActorError, this_host
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_initialization():
- """Test GPU manager initialization."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
- available_gpus = await manager.get_available_gpus.call_one()
-
- # Should have 8 GPUs available by default
- assert available_gpus == [str(i) for i in range(8)]
- assert len(available_gpus) == 8
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_get_gpus_basic():
- """Test basic GPU allocation."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- # Request 2 GPUs
- result = await manager.get_gpus.call_one(2)
-
- # Should return 2 GPU IDs as strings
- assert len(result) == 2
- assert all(isinstance(gpu_id, str) for gpu_id in result)
-
- # Should be valid GPU IDs (0-7)
- gpu_ints = [int(gpu_id) for gpu_id in result]
- assert all(0 <= gpu_id <= 7 for gpu_id in gpu_ints)
-
- # Check remaining available GPUs
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 6
-
- # Allocated GPUs should no longer be available
- for gpu_id in gpu_ints:
- assert str(gpu_id) not in available_gpus
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_get_gpus_all():
- """Test allocating all available GPUs."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- # Request all 8 GPUs
- result = await manager.get_gpus.call_one(8)
-
- assert len(result) == 8
-
- # Check no GPUs are available
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 0
-
- # All original GPUs should be allocated
- allocated_ints = {int(gpu_id) for gpu_id in result}
- assert allocated_ints == set(range(8))
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_get_gpus_insufficient():
- """Test error when requesting more GPUs than available."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- # Request more than 8 GPUs should raise an error
- with pytest.raises(ActorError, match="Not enough GPUs available"):
- await manager.get_gpus.call_one(9)
-
- # Available GPUs should remain unchanged
- available_gpus = await manager.get_available_gpus.call_one()
- assert available_gpus == [str(i) for i in range(8)]
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_get_gpus_zero():
- """Test requesting zero GPUs."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- result = await manager.get_gpus.call_one(0)
-
- assert result == []
-
- # Available GPUs should remain unchanged
- available_gpus = await manager.get_available_gpus.call_one()
- assert available_gpus == [str(i) for i in range(8)]
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_release_gpus_basic():
- """Test basic GPU release functionality."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- # Allocate some GPUs
- allocated = await manager.get_gpus.call_one(3)
-
- # Check they're no longer available
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 5
-
- # Release them back
- await manager.release_gpus.call_one(allocated)
-
- # Should have all 8 GPUs available again
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 8
- assert set(available_gpus) == {str(i) for i in range(8)}
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_release_gpus_partial():
- """Test releasing only some of the allocated GPUs."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- # Allocate 4 GPUs
- allocated = await manager.get_gpus.call_one(4)
-
- # Check available count
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 4
-
- # Release only 2 of them
- to_release = allocated[:2]
- await manager.release_gpus.call_one(to_release)
-
- # Should have 6 GPUs available now
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 6
-
- # The released GPUs should be back in available set
- available_ints = {int(gpu_id) for gpu_id in available_gpus}
- for gpu_id in to_release:
- assert int(gpu_id) in available_ints
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_release_gpus_empty():
- """Test releasing empty list of GPUs."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- await manager.release_gpus.call_one([])
-
- # Should remain unchanged
- available_gpus = await manager.get_available_gpus.call_one()
- assert available_gpus == [str(i) for i in range(8)]
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_allocation_release_cycle():
- """Test multiple allocation and release cycles."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- # Cycle 1: Allocate 3, release 3
- batch1 = await manager.get_gpus.call_one(3)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 5
-
- await manager.release_gpus.call_one(batch1)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 8
-
- # Cycle 2: Allocate 5, release 5
- batch2 = await manager.get_gpus.call_one(5)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 3
-
- await manager.release_gpus.call_one(batch2)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 8
-
- # Should be back to original state
- assert set(available_gpus) == {str(i) for i in range(8)}
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_incremental_allocation():
- """Test incremental allocation until exhaustion."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
- all_allocated = []
-
- # Allocate in chunks
- batch1 = await manager.get_gpus.call_one(2) # 6 remaining
- all_allocated.extend(batch1)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 6
-
- batch2 = await manager.get_gpus.call_one(3) # 3 remaining
- all_allocated.extend(batch2)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 3
-
- batch3 = await manager.get_gpus.call_one(3) # 0 remaining
- all_allocated.extend(batch3)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 0
-
- # Should have allocated all 8 GPUs
- assert len(all_allocated) == 8
-
- # Should fail to allocate more
- with pytest.raises(ActorError):
- await manager.get_gpus.call_one(1)
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_concurrent_operations_simulation():
- """Test simulated concurrent operations."""
- p = this_host().spawn_procs(per_host={"cpus": 1})
- manager = p.spawn("GpuManager", GpuManager)
-
- # Simulate multiple "jobs" allocating and releasing
- job1_gpus = await manager.get_gpus.call_one(2)
- job2_gpus = await manager.get_gpus.call_one(3)
-
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 3
-
- # Job1 releases its GPUs
- await manager.release_gpus.call_one(job1_gpus)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 5
-
- # Job3 allocates some GPUs
- job3_gpus = await manager.get_gpus.call_one(4)
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 1
-
- # Job2 and Job3 release
- await manager.release_gpus.call_one(job2_gpus)
- await manager.release_gpus.call_one(job3_gpus)
-
- # Should be back to full capacity
- available_gpus = await manager.get_available_gpus.call_one()
- assert len(available_gpus) == 8
- assert set(available_gpus) == {str(i) for i in range(8)}
diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_policy_config.py
deleted file mode 100644
index 08de4f907..000000000
--- a/tests/unit_tests/test_policy_config.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import tempfile
-import unittest
-
-import pytest
-import yaml
-
-
-def _import_error():
- """Check if there are import errors that would cause CI failures."""
- try:
- import forge.actors.policy # noqa: F401
-
- return False
- except ImportError:
- return True
-
-
-class TestPolicyConfig(unittest.TestCase):
- """Test suite for Policy configuration handling after PolicyConfig removal."""
-
- @pytest.mark.skipif(
- _import_error(),
- reason="Import error, likely due to missing dependencies on CI.",
- )
- def test_policy_default_initialization(self):
- """Policy initializes with default values."""
- from forge.actors.policy import EngineConfig, Policy, SamplingConfig
-
- policy = Policy()
-
- # Default factories
- self.assertIsInstance(policy.engine_config, EngineConfig)
- self.assertIsInstance(policy.sampling_config, SamplingConfig)
- self.assertIsNone(policy.available_devices)
-
- # Worker defaults
- self.assertEqual(policy.engine_config.model, "meta-llama/Llama-3.1-8B-Instruct")
- self.assertEqual(policy.engine_config.tensor_parallel_size, 1)
- self.assertEqual(policy.engine_config.pipeline_parallel_size, 1)
- self.assertFalse(policy.engine_config.enforce_eager)
- self.assertTrue(policy.engine_config._is_v1_supported_oracle())
-
- # Sampling defaults
- self.assertEqual(policy.sampling_config.n, 1)
- self.assertFalse(policy.sampling_config.guided_decoding)
- self.assertEqual(policy.sampling_config.max_tokens, 512)
-
- @pytest.mark.skipif(
- _import_error(),
- reason="Import error, likely due to missing dependencies on CI.",
- )
- def test_policy_with_dict_configs(self):
- """Policy accepts dicts for engine_config and sampling_config, including nested dicts."""
- from forge.actors.policy import EngineConfig, Policy, SamplingConfig
-
- # Test with nested dict structure
- engine_dict = {
- "model": "test-model-6789",
- "tensor_parallel_size": 7777,
- "pipeline_parallel_size": 8888,
- "enforce_eager": True,
- "nested_config": {
- "gpu_memory_utilization": 0.9,
- "max_model_len": 4096,
- "custom_settings": {"temperature": 0.7, "top_p": 0.9},
- },
- }
-
- sampling_dict = {
- "n": 1357,
- "guided_decoding": True,
- "max_tokens": 2468,
- }
-
- policy = Policy(
- engine_config=engine_dict,
- sampling_config=sampling_dict,
- available_devices="test-gpu-device-abcd",
- )
-
- self.assertIsInstance(policy.engine_config, EngineConfig)
- self.assertIsInstance(policy.sampling_config, SamplingConfig)
-
- # Test basic fields
- self.assertEqual(policy.engine_config.model, "test-model-6789")
- self.assertEqual(policy.engine_config.tensor_parallel_size, 7777)
- self.assertEqual(policy.engine_config.pipeline_parallel_size, 8888)
- self.assertTrue(policy.engine_config.enforce_eager)
- self.assertTrue(policy.engine_config._is_v1_supported_oracle())
-
- self.assertEqual(policy.sampling_config.n, 1357)
- # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
- self.assertIsNotNone(policy.sampling_config.guided_decoding)
- self.assertEqual(policy.sampling_config.max_tokens, 2468)
-
- # Test that engine_dict accepts and preserves nested dict structure
- # The original engine_dict should remain unchanged and accessible
- self.assertIn("nested_config", engine_dict)
- self.assertEqual(engine_dict["nested_config"]["gpu_memory_utilization"], 0.9)
- self.assertEqual(engine_dict["nested_config"]["max_model_len"], 4096)
- self.assertEqual(
- engine_dict["nested_config"]["custom_settings"]["temperature"], 0.7
- )
- self.assertEqual(engine_dict["nested_config"]["custom_settings"]["top_p"], 0.9)
-
- @pytest.mark.skipif(
- _import_error(),
- reason="Import error, likely due to missing dependencies on CI.",
- )
- def test_policy_yaml_config_loading(self):
- """Policy can be constructed from a YAML config file."""
- from forge.actors.policy import Policy
-
- yaml_content = """
- engine_config:
- model: "yaml-test-model-9876"
- tensor_parallel_size: 1234
- pipeline_parallel_size: 5678
- enforce_eager: true
-
- sampling_config:
- n: 2468
- guided_decoding: true
- max_tokens: 1357
-
- available_devices: "yaml-test-device-xyz"
- """
-
- with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
- f.write(yaml_content)
- f.flush()
-
- with open(f.name, "r") as yaml_file:
- config = yaml.safe_load(yaml_file)
-
- policy = Policy(**config)
-
- self.assertEqual(policy.engine_config.model, "yaml-test-model-9876")
- self.assertEqual(policy.engine_config.tensor_parallel_size, 1234)
- self.assertEqual(policy.engine_config.pipeline_parallel_size, 5678)
- self.assertTrue(policy.engine_config.enforce_eager)
- self.assertTrue(policy.engine_config._is_v1_supported_oracle())
-
- self.assertEqual(policy.sampling_config.n, 2468)
- # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
- self.assertIsNotNone(policy.sampling_config.guided_decoding)
- self.assertEqual(policy.sampling_config.max_tokens, 1357)
-
- self.assertEqual(policy.available_devices, "yaml-test-device-xyz")
-
- @pytest.mark.skipif(
- _import_error(),
- reason="Import error, likely due to missing dependencies on CI.",
- )
- def test_engineconfig_ignores_invalid_keys(self):
- """EngineConfig.from_dict ignores unexpected keys."""
- from forge.actors.policy import EngineConfig
-
- engine_config = {
- "model": "custom-model",
- "tensor_parallel_size": 2,
- "invalid_key_123": "should be ignored",
- }
-
- config = EngineConfig.from_dict(engine_config)
-
- self.assertEqual(config.model, "custom-model")
- self.assertEqual(config.tensor_parallel_size, 2)
- self.assertFalse(hasattr(config, "invalid_key_123"))
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/unit_tests/test_provisioner.py b/tests/unit_tests/test_provisioner.py
index 888e75e3c..10686412d 100644
--- a/tests/unit_tests/test_provisioner.py
+++ b/tests/unit_tests/test_provisioner.py
@@ -45,9 +45,6 @@ def test_gpu_manager_invalid_device_range(self):
with pytest.raises(AssertionError):
GpuManager(available_devices={-1}) # Negative device
- with pytest.raises(AssertionError):
- GpuManager(available_devices={8}) # Device >= 8
-
with pytest.raises(AssertionError):
GpuManager(available_devices={"0"}) # String instead of int
@@ -90,7 +87,8 @@ class TestProvisionerCudaVisibleDevices:
"""Test Provisioner's handling of CUDA_VISIBLE_DEVICES environment variable."""
@mock.patch.dict(os.environ, {}, clear=True)
- def test_provisioner_no_cuda_visible_devices(self):
+ @mock.patch("torch.cuda.device_count", return_value=8)
+ def test_provisioner_no_cuda_visible_devices(self, mock_device_count):
"""Test Provisioner when CUDA_VISIBLE_DEVICES is not set."""
provisioner = Provisioner()
@@ -135,7 +133,8 @@ def test_provisioner_duplicate_gpu_ids(self):
assert len(available) == 3
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": ""}, clear=True)
- def test_provisioner_empty_cuda_visible_devices(self):
+ @mock.patch("torch.cuda.device_count", return_value=8)
+ def test_provisioner_empty_cuda_visible_devices(self, mock_device_count):
"""Test Provisioner with empty CUDA_VISIBLE_DEVICES."""
provisioner = Provisioner()
@@ -144,7 +143,7 @@ def test_provisioner_empty_cuda_visible_devices(self):
available = local_gpu_manager.get_available_gpus()
assert available == [str(i) for i in range(8)]
- @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,2"}, clear=True)
+ @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1,2"}, clear=False)
@pytest.mark.asyncio
async def test_get_proc_mesh_respects_cuda_visible_devices(self):
"""Test that get_proc_mesh uses CUDA_VISIBLE_DEVICES for local allocation."""
@@ -161,6 +160,8 @@ async def test_get_proc_mesh_respects_cuda_visible_devices(self):
num_procs=2,
with_gpus=True,
num_hosts=None,
+ port="12345",
+ addr="localhost",
)
# Verify GPUs were allocated from available set
remaining_available = local_gpu_manager.get_available_gpus()
@@ -243,3 +244,47 @@ def test_single_gpu_scenario(self):
# Release and verify
local_gpu_manager.release_gpus(allocated)
assert local_gpu_manager.get_available_gpus() == ["0"]
+
+
+class TestDynamicGpuDetection:
+ """Test dynamic GPU detection using torch.cuda.device_count()."""
+
+ @mock.patch.dict(os.environ, {}, clear=True)
+ @mock.patch("torch.cuda.device_count", return_value=4)
+ def test_provisioner_with_4_gpus(self, mock_device_count):
+ """Test Provisioner detects 4 GPUs when torch.cuda.device_count() returns 4."""
+ provisioner = Provisioner()
+
+ local_gpu_manager = provisioner._host_gpu_map[provisioner._this_host_id]
+ available = local_gpu_manager.get_available_gpus()
+ assert sorted(available) == ["0", "1", "2", "3"]
+ assert len(available) == 4
+ assert local_gpu_manager.max_device_count == 4
+
+ @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,2,4"}, clear=True)
+ @mock.patch("torch.cuda.device_count", return_value=8)
+ def test_cuda_visible_devices_with_detected_gpus(self, mock_device_count):
+ """Test that CUDA_VISIBLE_DEVICES works correctly with detected GPU count."""
+ provisioner = Provisioner()
+
+ local_gpu_manager = provisioner._host_gpu_map[provisioner._this_host_id]
+ available = local_gpu_manager.get_available_gpus()
+ # Should use CUDA_VISIBLE_DEVICES, not all 8 detected GPUs
+ assert sorted(available) == ["0", "2", "4"]
+ assert len(available) == 3
+ # max_device_count should still be 8 from detection
+ assert local_gpu_manager.max_device_count == 8
+
+ @mock.patch.dict(os.environ, {}, clear=True)
+ @mock.patch(
+ "torch.cuda.device_count", side_effect=RuntimeError("CUDA not available")
+ )
+ def test_provisioner_when_cuda_unavailable(self, mock_device_count):
+ """Test Provisioner defaults to 0 GPUs when CUDA is not available."""
+ provisioner = Provisioner()
+
+ local_gpu_manager = provisioner._host_gpu_map[provisioner._this_host_id]
+ available = local_gpu_manager.get_available_gpus()
+ assert available == []
+ assert len(available) == 0
+ assert local_gpu_manager.max_device_count == 0
diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py
index f6c5bc574..10053b78f 100644
--- a/tests/unit_tests/test_replay_buffer.py
+++ b/tests/unit_tests/test_replay_buffer.py
@@ -6,16 +6,31 @@
"""Test for data/replay_buffer.py"""
+from dataclasses import dataclass
+
import pytest
import pytest_asyncio
from forge.actors.replay_buffer import ReplayBuffer
-from forge.types import Trajectory
+
+
+@dataclass
+class TestEpisode:
+ """
+ Dummy Episode containing just a policy version
+
+ ReplayBuffer expects any construct (typically an Episode) that contains a
+ `policy_version`.
+
+ TODO: Replaced with a unified interface in the future.
+ """
+
+ policy_version: int
class TestReplayBuffer:
@pytest_asyncio.fixture
async def replay_buffer(self) -> ReplayBuffer:
- replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
+ replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=False).as_actor(
batch_size=2, max_policy_age=1
)
await replay_buffer.setup.call()
@@ -23,27 +38,27 @@ async def replay_buffer(self) -> ReplayBuffer:
@pytest.mark.asyncio
async def test_add(self, replay_buffer: ReplayBuffer) -> None:
- trajectory = Trajectory(policy_version=0)
- await replay_buffer.add.call_one(trajectory)
+ episode = TestEpisode(policy_version=0)
+ await replay_buffer.add.call_one(episode)
assert replay_buffer._numel.call_one().get() == 1
- assert replay_buffer._getitem.call_one(0).get() == trajectory
+ assert replay_buffer._getitem.call_one(0).get() == episode
replay_buffer.clear.call_one().get()
@pytest.mark.asyncio
async def test_add_multiple(self, replay_buffer) -> None:
- trajectory_0 = Trajectory(policy_version=0)
- trajectory_1 = Trajectory(policy_version=1)
- await replay_buffer.add.call_one(trajectory_0)
- await replay_buffer.add.call_one(trajectory_1)
+ episode_0 = TestEpisode(policy_version=0)
+ episode_1 = TestEpisode(policy_version=1)
+ await replay_buffer.add.call_one(episode_0)
+ await replay_buffer.add.call_one(episode_1)
assert replay_buffer._numel.call_one().get() == 2
- assert replay_buffer._getitem.call_one(0).get() == trajectory_0
- assert replay_buffer._getitem.call_one(1).get() == trajectory_1
+ assert replay_buffer._getitem.call_one(0).get() == episode_0
+ assert replay_buffer._getitem.call_one(1).get() == episode_1
replay_buffer.clear.call_one().get()
@pytest.mark.asyncio
async def test_state_dict_save_load(self, replay_buffer) -> None:
- trajectory = Trajectory(policy_version=0)
- await replay_buffer.add.call_one(trajectory)
+ episode = TestEpisode(policy_version=0)
+ await replay_buffer.add.call_one(episode)
state_dict = replay_buffer.state_dict.call_one().get()
replay_buffer.clear.call_one().get()
assert replay_buffer._numel.call_one().get() == 0
@@ -53,10 +68,10 @@ async def test_state_dict_save_load(self, replay_buffer) -> None:
@pytest.mark.asyncio
async def test_evict(self, replay_buffer) -> None:
- trajectory_0 = Trajectory(policy_version=0)
- trajectory_1 = Trajectory(policy_version=1)
- await replay_buffer.add.call_one(trajectory_0)
- await replay_buffer.add.call_one(trajectory_1)
+ episode_0 = TestEpisode(policy_version=0)
+ episode_1 = TestEpisode(policy_version=1)
+ await replay_buffer.add.call_one(episode_0)
+ await replay_buffer.add.call_one(episode_1)
assert replay_buffer._numel.call_one().get() == 2
await replay_buffer.evict.call_one(curr_policy_version=2)
assert replay_buffer._numel.call_one().get() == 1
@@ -64,61 +79,56 @@ async def test_evict(self, replay_buffer) -> None:
@pytest.mark.asyncio
async def test_sample(self, replay_buffer) -> None:
- trajectory_0 = Trajectory(policy_version=0)
- trajectory_1 = Trajectory(policy_version=1)
- await replay_buffer.add.call_one(trajectory_0)
- await replay_buffer.add.call_one(trajectory_1)
+ episode_0 = TestEpisode(policy_version=0)
+ episode_1 = TestEpisode(policy_version=1)
+ await replay_buffer.add.call_one(episode_0)
+ await replay_buffer.add.call_one(episode_1)
assert replay_buffer._numel.call_one().get() == 2
- # Test a simple sampling w/ no evictions
+ # Test a simple sampling
samples = await replay_buffer.sample.call_one(curr_policy_version=1)
assert samples is not None
assert len(samples[0]) == 2
+ assert replay_buffer._numel.call_one().get() == 2
- # Test sampling with overriding batch size
- await replay_buffer.add.call_one(trajectory_0)
- samples = await replay_buffer.sample.call_one(
- curr_policy_version=1, batch_size=1
- )
- assert samples is not None
- assert len(samples[0]) == 1
-
- # Test sampling w/ overriding batch size (not enough samples in buffer, returns None)
- await replay_buffer.add.call_one(trajectory_0)
- samples = await replay_buffer.sample.call_one(
- curr_policy_version=1, batch_size=3
- )
+ # Test sampling (not enough samples in buffer, returns None)
+ await replay_buffer.add.call_one(episode_0)
+ samples = await replay_buffer.sample.call_one(curr_policy_version=1)
assert samples is None
replay_buffer.clear.call_one().get()
@pytest.mark.asyncio
async def test_sample_with_evictions(self, replay_buffer) -> None:
- trajectory_0 = Trajectory(policy_version=0)
- trajectory_1 = Trajectory(policy_version=1)
- await replay_buffer.add.call_one(trajectory_0)
- await replay_buffer.add.call_one(trajectory_1)
- assert replay_buffer._numel.call_one().get() == 2
+ episode_0 = TestEpisode(policy_version=0)
+ episode_1 = TestEpisode(policy_version=1)
+ episode_2 = TestEpisode(policy_version=2)
+ await replay_buffer.add.call_one(episode_0)
+ await replay_buffer.add.call_one(episode_1)
+ await replay_buffer.add.call_one(episode_2)
+ assert replay_buffer._numel.call_one().get() == 3
samples = await replay_buffer.sample.call_one(
- curr_policy_version=2, batch_size=1
+ curr_policy_version=2,
)
assert samples is not None
- assert len(samples[0]) == 1
- assert samples[0][0] == trajectory_1
+ assert len(samples[0]) == 2
+ assert samples[0][0].policy_version > 0
+ assert samples[0][1].policy_version > 0
+ assert replay_buffer._numel.call_one().get() == 2
replay_buffer.clear.call_one().get()
@pytest.mark.asyncio
async def test_sample_dp_size(self) -> None:
"""Test that len(samples) == dp_size when sampling."""
# Create replay buffer with dp_size=3
- replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
+ replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=False).as_actor(
batch_size=2, max_policy_age=1, dp_size=3
)
await replay_buffer.setup.call()
# Add enough trajectories to sample
for i in range(10):
- trajectory = Trajectory(policy_version=0)
- await replay_buffer.add.call_one(trajectory)
+ episode = TestEpisode(policy_version=0)
+ await replay_buffer.add.call_one(episode)
# Sample and verify len(samples) == dp_size
samples = await replay_buffer.sample.call_one(curr_policy_version=0)
@@ -129,3 +139,16 @@ async def test_sample_dp_size(self) -> None:
assert len(dp_samples) == 2 # batch_size
replay_buffer.clear.call_one().get()
+
+ @pytest.mark.asyncio
+ async def test_collect(self) -> None:
+ """Test _collect method"""
+ local_rb = ReplayBuffer(batch_size=1)
+ await local_rb.setup._method(local_rb)
+ for i in range(1, 6):
+ local_rb.buffer.append(i)
+ values = local_rb._collect([2, 0, -1])
+ assert values == [3, 1, 5]
+ values = local_rb._collect([1, 3])
+ assert values == [2, 4]
+ assert local_rb.buffer[0] == 1
diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py
index 31a912542..910066469 100644
--- a/tests/unit_tests/test_service.py
+++ b/tests/unit_tests/test_service.py
@@ -11,6 +11,8 @@
import asyncio
import logging
+import monarch
+
import pytest
from forge.controller import ForgeActor
from forge.controller.service import (
@@ -24,6 +26,11 @@
from forge.types import ProcessConfig
from monarch.actor import Actor, endpoint
+# Temporary workaround - without this, proc_mesh.stop
+# will raise an exit code 1 failing all other tests.
+monarch.actor.unhandled_fault_hook = lambda failure: None
+
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -81,7 +88,7 @@ def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica:
@pytest.mark.asyncio
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
async def test_as_actor_with_args_config():
"""Test spawning a single actor with passing configs through kwargs."""
actor = await Counter.options(procs=1).as_actor(5)
@@ -98,7 +105,7 @@ async def test_as_actor_with_args_config():
@pytest.mark.asyncio
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
async def test_as_actor_default_usage():
"""Test spawning a single actor directly via .as_actor() using default config."""
actor = await Counter.as_actor(v=7)
@@ -115,12 +122,12 @@ async def test_as_actor_default_usage():
@pytest.mark.asyncio
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
async def test_options_applies_config():
"""Test config via options class."""
- actor_cls = Counter.options(procs=1, with_gpus=True, num_replicas=2)
+ actor_cls = Counter.options(procs=1, with_gpus=False, num_replicas=2)
assert actor_cls.procs == 1
- assert actor_cls.with_gpus is True
+ assert actor_cls.with_gpus is False
assert actor_cls.num_replicas == 2
actor = await actor_cls.as_actor(v=3)
@@ -133,7 +140,7 @@ async def test_options_applies_config():
# Service Config Tests
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_actor_def_type_validation():
"""Test that .options() rejects classes that are not ForgeActor subclasses."""
@@ -172,12 +179,12 @@ async def test_service_with_kwargs_config():
@pytest.mark.asyncio
async def test_service_default_config():
"""Construct with default configuration using as_service directly."""
- service = await Counter.as_service(10)
+ service = await Counter.as_service(30)
try:
cfg = service._service._cfg
assert cfg.num_replicas == 1
assert cfg.procs == 1
- assert await service.value.route() == 10
+ assert await service.value.route() == 30
finally:
await service.shutdown()
@@ -188,7 +195,7 @@ async def test_multiple_services_isolated_configs():
"""Ensure multiple services from the same actor class have independent configs."""
# Create first service with 2 replicas
- service1 = await Counter.options(num_replicas=2, procs=1).as_service(v=10)
+ service1 = await Counter.options(num_replicas=2, procs=1).as_service(v=30)
# Create second service with 4 replicas
service2 = await Counter.options(num_replicas=4, procs=1).as_service(v=20)
@@ -206,7 +213,7 @@ async def test_multiple_services_isolated_configs():
val1 = await service1.value.route()
val2 = await service2.value.route()
- assert val1 == 10
+ assert val1 == 30
assert val2 == 20
finally:
@@ -253,7 +260,7 @@ async def test_service_endpoint_monarch_method_error():
# Core Functionality Tests
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_basic_service_operations():
"""Test basic service creation, sessions, and endpoint calls."""
@@ -284,7 +291,7 @@ async def test_basic_service_operations():
await service.shutdown()
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_sessionless_calls():
"""Test sessionless calls with round robin load balancing."""
@@ -311,7 +318,7 @@ async def test_sessionless_calls():
# Users should be able to call endpoint with just args
result = await service.add_to_value.route(5, multiplier=2)
- assert result == 11 # 1 + 10
+ assert result == 11 # 1 + 30
finally:
await service.shutdown()
@@ -482,7 +489,7 @@ async def test_replica_failure_and_recovery():
# Metrics and Monitoring Tests
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_metrics_collection():
"""Test metrics collection."""
@@ -534,7 +541,7 @@ async def test_metrics_collection():
# Load Balancing and Session Management Tests
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_session_stickiness():
"""Test that sessions stick to the same replica."""
@@ -564,7 +571,7 @@ async def test_session_stickiness():
await service.shutdown()
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_load_balancing_multiple_sessions():
"""Test load balancing across multiple sessions using least-loaded assignment."""
@@ -612,7 +619,7 @@ async def test_load_balancing_multiple_sessions():
await service.shutdown()
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_concurrent_operations():
"""Test concurrent operations across sessions and sessionless calls."""
@@ -652,7 +659,7 @@ async def test_concurrent_operations():
# `call` endpoint tests
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_broadcast_call_basic():
"""Test basic broadcast call functionality."""
@@ -674,7 +681,7 @@ async def test_broadcast_call_basic():
assert isinstance(values, list)
assert len(values) == 3
- # All replicas should have incremented from 10 to 11
+ # All replicas should have incremented from 30 to 11
assert all(value == 11 for value in values)
finally:
@@ -683,7 +690,7 @@ async def test_broadcast_call_basic():
@pytest.mark.timeout(15)
@pytest.mark.asyncio
-async def test_broadcast_call_with_failed_replica():
+async def dont_test_broadcast_call_with_failed_replica():
"""Test broadcast call behavior when some replicas fail."""
service = await Counter.options(procs=1, num_replicas=3).as_service(v=0)
@@ -719,7 +726,7 @@ async def test_broadcast_call_with_failed_replica():
await service.shutdown()
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_broadcast_fanout_vs_route():
"""Test that broadcast fanout hits all replicas while route hits only one."""
@@ -759,8 +766,7 @@ async def test_broadcast_fanout_vs_route():
# Router Tests
-@pytest.mark.asyncio
-async def test_session_router_with_round_robin_fallback():
+def test_session_router_with_round_robin_fallback():
"""Switch fallback router to round-robin and verify assignment order."""
# Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas
replicas = [make_replica(0, load=0), make_replica(1, load=5)]
@@ -789,7 +795,7 @@ async def test_session_router_with_round_robin_fallback():
# Router integeration tests
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_round_robin_router_distribution():
"""Test that the RoundRobinRouter distributes sessionless calls evenly across replicas."""
@@ -814,7 +820,7 @@ async def test_round_robin_router_distribution():
await service.shutdown()
-@pytest.mark.timeout(10)
+@pytest.mark.timeout(30)
@pytest.mark.asyncio
async def test_session_router_assigns_and_updates_session_map_in_service():
"""Integration: Service with SessionRouter preserves sticky sessions."""
diff --git a/tests/unit_tests/test_trainer.py b/tests/unit_tests/test_trainer.py
deleted file mode 100644
index a5a6f290e..000000000
--- a/tests/unit_tests/test_trainer.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import shutil
-import tempfile
-import unittest
-
-from forge.actors.trainer import cleanup_old_weight_versions
-
-
-class TestTrainerUtilities(unittest.TestCase):
- def setUp(self):
- """Set up test environment with temporary directory."""
- self.test_dir = tempfile.mkdtemp()
- self.addCleanup(shutil.rmtree, self.test_dir)
-
- def test_cleanup_old_weight_versions_basic(self):
- """Test basic cleanup functionality - keeps current and N-1 versions."""
- # Create test directory structure
- state_dict_key = os.path.join(self.test_dir, "model")
- delim = "__"
-
- # Create some mock weight directories
- old_version_1 = f"{state_dict_key}{delim}1"
- previous_version = f"{state_dict_key}{delim}2" # N-1 version
- current_version = f"{state_dict_key}{delim}3" # Current version
- unrelated_dir = os.path.join(self.test_dir, "other_model__1")
-
- for dir_path in [
- old_version_1,
- previous_version,
- current_version,
- unrelated_dir,
- ]:
- os.makedirs(dir_path)
-
- # Run cleanup for version 3
- cleanup_old_weight_versions(
- state_dict_key=state_dict_key,
- delim=delim,
- current_policy_version=3,
- )
-
- # Check that only very old versions were deleted (version 1)
- self.assertFalse(os.path.exists(old_version_1))
-
- # Check that current and previous versions still exist
- self.assertTrue(os.path.exists(previous_version)) # N-1 version should remain
- self.assertTrue(
- os.path.exists(current_version)
- ) # Current version should remain
- self.assertTrue(os.path.exists(unrelated_dir)) # Unrelated dirs should remain
-
- def test_cleanup_old_weight_versions_no_cleanup_version_1(self):
- """Test that no cleanup happens when current_policy_version <= 1."""
- # Create test directory structure
- state_dict_key = os.path.join(self.test_dir, "model")
- delim = "__"
-
- version_1 = f"{state_dict_key}{delim}1"
- os.makedirs(version_1)
-
- # Run cleanup for version 1 - should do nothing
- cleanup_old_weight_versions(
- state_dict_key=state_dict_key,
- delim=delim,
- current_policy_version=1,
- )
-
- # Version 1 should still exist
- self.assertTrue(os.path.exists(version_1))
-
- def test_cleanup_old_weight_versions_version_2(self):
- """Test cleanup with version 2 as current - should keep versions 1 and 2."""
- # Create test directory structure
- state_dict_key = os.path.join(self.test_dir, "model")
- delim = "__"
-
- version_1 = f"{state_dict_key}{delim}1" # N-1 version
- version_2 = f"{state_dict_key}{delim}2" # Current version
-
- for dir_path in [version_1, version_2]:
- os.makedirs(dir_path)
-
- # Run cleanup for version 2
- cleanup_old_weight_versions(
- state_dict_key=state_dict_key,
- delim=delim,
- current_policy_version=2,
- )
-
- # Both versions should still exist (no deletion for version 2)
- self.assertTrue(os.path.exists(version_1))
- self.assertTrue(os.path.exists(version_2))
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/unit_tests/util/test_compute_logprobs.py b/tests/unit_tests/util/test_compute_logprobs.py
deleted file mode 100644
index c4e3bffcb..000000000
--- a/tests/unit_tests/util/test_compute_logprobs.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import pytest
-import torch
-import torch.nn.functional as F
-from forge.util.ops import compute_logprobs
-
-
-def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
- # Helper: Textbook Log Softmax
- log_probs = F.log_softmax(logits, dim=-1)
- return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
-
-
-class TestComputeLogprobs:
- def test_single_batch_item(self):
- """Test with single batch item."""
- # Shape: (1, 2, 3)
- logits = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
- # Shape: (1, 1)
- input_ids = torch.tensor([[1]])
- result = compute_logprobs(logits, input_ids)
-
- # Manual calculation
- expected_logits = torch.tensor([[[1.0, 2.0, 3.0]]])
- expected = _textbook_log_softmax(expected_logits, input_ids)
-
- assert torch.allclose(result, expected, atol=1e-5)
- assert result.shape == (1, 1)
-
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
- # Shape: (1, 3, 3)
- logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
- # Shape: (1, 2)
- input_ids = torch.tensor([[2, 0]])
- result = compute_logprobs(logits, input_ids)
-
- # Manual calculation
- expected_logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])
- expected = _textbook_log_softmax(expected_logits, input_ids)
-
- assert torch.allclose(result, expected, atol=1e-5)
- assert result.shape == (1, 2)
-
- @pytest.mark.timeout(10)
- def test_multi_batch(self):
- """Test with multiple batch items."""
- # Shape: (2, 2, 3)
- logits = torch.tensor(
- [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]]
- )
- # Shape: (2, 1)
- input_ids = torch.tensor([[1], [2]])
- result = compute_logprobs(logits, input_ids)
-
- # Manual calculation
- expected_logits = torch.tensor([[[1.0, 2.0, 3.0]], [[0.5, 1.5, 2.5]]])
- expected = _textbook_log_softmax(expected_logits, input_ids)
-
- assert torch.allclose(result, expected, atol=1e-5)
- assert result.shape == (2, 1)
-
- @pytest.mark.timeout(10)
- def test_temperature(self):
- """Test with different temperature values."""
- batch_size, seq_len, vocab_size = 2, 4, 6
- logits = torch.randn(batch_size, seq_len, vocab_size)
- input_ids = torch.randint(0, vocab_size, (batch_size, seq_len - 1))
-
- # Manual calculation with temperature scaling
- def _manual(temperature: float):
- expected_logits = logits[:, 0:-1] / temperature
- return _textbook_log_softmax(expected_logits, input_ids)
-
- temperatures = [1.0, 2.0, 4.5]
- for temperature in temperatures:
- result = compute_logprobs(logits, input_ids, temperature=temperature)
- expected = _manual(temperature)
- assert torch.allclose(result, expected, atol=1e-5)
- assert result.shape == input_ids.shape
-
- @pytest.mark.timeout(10)
- def test_edge_cases(self):
- """Test edge cases."""
- # Test with very large values (numerical stability)
- logits = torch.tensor([[[1000.0, 2000.0], [1500.0, 2500.0]]])
- input_ids = torch.tensor([[0]])
- result = compute_logprobs(logits, input_ids)
- # Should not be NaN or inf
- assert torch.isfinite(result).all()
-
- # Test with very small values
- logits = torch.tensor([[[-1000.0, -2000.0], [-1500.0, -2500.0]]])
- input_ids = torch.tensor([[1]])
- result = compute_logprobs(logits, input_ids)
- # Should not be NaN or inf
- assert torch.isfinite(result).all()
-
- def test_compute_logprobs_empty_response(self):
- """Test logprobs computation with empty response."""
- batch_size, seq_len, vocab_size = 1, 5, 1000
- logits = torch.randn(batch_size, seq_len, vocab_size)
- input_ids = torch.tensor([[]])
-
- result = compute_logprobs(logits, input_ids)
- assert result.shape == (batch_size, 0)
diff --git a/tests/unit_tests/util/test_ops.py b/tests/unit_tests/util/test_ops.py
new file mode 100644
index 000000000..b9e929120
--- /dev/null
+++ b/tests/unit_tests/util/test_ops.py
@@ -0,0 +1,229 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+from forge.util.ops import compute_logprobs
+
+from tests.test_utils import gpu_test
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import DTensor, Shard
+from torch.distributed.tensor.parallel import loss_parallel
+from torch.testing._internal.common_fsdp import FSDPTest
+
+
+def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
+ # Helper: Textbook Log Softmax
+ log_probs = F.log_softmax(logits, dim=-1)
+ return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
+
+
+class TestComputeLogprobs:
+ def test_single_batch_item(self):
+ """Test with single batch item."""
+ # Shape: (1, 2, 3)
+ logits = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
+ # Shape: (1, 1)
+ input_ids = torch.tensor([[1]])
+ result = compute_logprobs(logits, input_ids)
+
+ # Manual calculation
+ expected_logits = torch.tensor([[[1.0, 2.0, 3.0]]])
+ expected = _textbook_log_softmax(expected_logits, input_ids)
+
+ assert torch.allclose(result, expected, atol=1e-5)
+ assert result.shape == (1, 1)
+
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ # Shape: (1, 3, 3)
+ logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
+ # Shape: (1, 2)
+ input_ids = torch.tensor([[2, 0]])
+ result = compute_logprobs(logits, input_ids)
+
+ # Manual calculation
+ expected_logits = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])
+ expected = _textbook_log_softmax(expected_logits, input_ids)
+
+ assert torch.allclose(result, expected, atol=1e-5)
+ assert result.shape == (1, 2)
+
+ @pytest.mark.timeout(10)
+ def test_multi_batch(self):
+ """Test with multiple batch items."""
+ # Shape: (2, 2, 3)
+ logits = torch.tensor(
+ [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]]
+ )
+ # Shape: (2, 1)
+ input_ids = torch.tensor([[1], [2]])
+ result = compute_logprobs(logits, input_ids)
+
+ # Manual calculation
+ expected_logits = torch.tensor([[[1.0, 2.0, 3.0]], [[0.5, 1.5, 2.5]]])
+ expected = _textbook_log_softmax(expected_logits, input_ids)
+
+ assert torch.allclose(result, expected, atol=1e-5)
+ assert result.shape == (2, 1)
+
+ @pytest.mark.timeout(10)
+ def test_temperature(self):
+ """Test with different temperature values."""
+ batch_size, seq_len, vocab_size = 2, 4, 6
+ logits = torch.randn(batch_size, seq_len, vocab_size)
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len - 1))
+
+ # Manual calculation with temperature scaling
+ def _manual(temperature: float):
+ expected_logits = logits[:, 0:-1] / temperature
+ return _textbook_log_softmax(expected_logits, input_ids)
+
+ temperatures = [1.0, 2.0, 4.5]
+ for temperature in temperatures:
+ result = compute_logprobs(logits, input_ids, temperature=temperature)
+ expected = _manual(temperature)
+ assert torch.allclose(result, expected, atol=1e-5)
+ assert result.shape == input_ids.shape
+
+ @pytest.mark.timeout(10)
+ def test_edge_cases(self):
+ """Test edge cases."""
+ # Test with very large values (numerical stability)
+ logits = torch.tensor([[[1000.0, 2000.0], [1500.0, 2500.0]]])
+ input_ids = torch.tensor([[0]])
+ result = compute_logprobs(logits, input_ids)
+ # Should not be NaN or inf
+ assert torch.isfinite(result).all()
+
+ # Test with very small values
+ logits = torch.tensor([[[-1000.0, -2000.0], [-1500.0, -2500.0]]])
+ input_ids = torch.tensor([[1]])
+ result = compute_logprobs(logits, input_ids)
+ # Should not be NaN or inf
+ assert torch.isfinite(result).all()
+
+ def test_compute_logprobs_empty_response(self):
+ """Test logprobs computation with empty response."""
+ batch_size, seq_len, vocab_size = 1, 5, 1000
+ logits = torch.randn(batch_size, seq_len, vocab_size)
+ input_ids = torch.tensor([[]])
+
+ result = compute_logprobs(logits, input_ids)
+ assert result.shape == (batch_size, 0)
+
+ @pytest.mark.timeout(10)
+ def test_align_parameter_false(self):
+ """Test with align=False (pre-aligned logits)."""
+ # When align=False, logits are already aligned with input_ids
+ # logits[:, i] predicts input_ids[:, i]
+ batch_size, seq_len, vocab_size = 2, 3, 5
+ logits = torch.randn(batch_size, seq_len, vocab_size)
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
+
+ result = compute_logprobs(logits, input_ids, align=False)
+
+ # Manual calculation without slicing
+ expected = _textbook_log_softmax(logits, input_ids)
+
+ assert torch.allclose(result, expected, atol=1e-5)
+ assert result.shape == input_ids.shape
+
+ @pytest.mark.timeout(10)
+ def test_align_parameter_true(self):
+ """Test with align=True (default, needs slicing)."""
+ # When align=True, logits need to be sliced to align with input_ids
+ batch_size, full_seq_len, vocab_size = 2, 6, 5
+ logits = torch.randn(batch_size, full_seq_len, vocab_size)
+
+ # We want log probs for just the last 3 tokens
+ target_len = 3
+ input_ids = torch.randint(0, vocab_size, (batch_size, target_len))
+
+ result = compute_logprobs(logits, input_ids, align=True)
+
+ # Manual calculation: align=True slices logits[:, -target_len-1:-1]
+ sliced_logits = logits[:, -target_len - 1 : -1, :]
+ expected = _textbook_log_softmax(sliced_logits, input_ids)
+
+ assert torch.allclose(result, expected, atol=1e-5)
+ assert result.shape == input_ids.shape
+
+ @pytest.mark.timeout(10)
+ def test_align_comparison(self):
+ """Test that align=True properly slices logits."""
+ batch_size, seq_len, vocab_size = 1, 4, 10
+ logits = torch.randn(batch_size, seq_len, vocab_size)
+ input_ids = torch.randint(0, vocab_size, (batch_size, 2))
+
+ result_aligned = compute_logprobs(logits, input_ids, align=True)
+
+ # Manually slice the same way align=True does
+ sliced_logits = logits[:, -input_ids.size(1) - 1 : -1, :]
+ result_manual = compute_logprobs(sliced_logits, input_ids, align=False)
+
+ # Both should give the same result
+ assert torch.allclose(result_aligned, result_manual, atol=1e-5)
+
+
+class TestComputeLogprobsWithLossParallel(FSDPTest):
+ """Test compute_logprobs with loss_parallel context for vocab-sharded DTensors."""
+
+ @property
+ def world_size(self) -> int:
+ return 2
+
+ @gpu_test(gpu_count=2)
+ def test_loss_parallel_matches_sequential(self):
+ """Verify compute_logprobs under loss_parallel matches non-sharded version."""
+ torch.manual_seed(42)
+
+ batch_size, seq_len, vocab_size, target_len = 4, 16, 1000, 8
+ rank = dist.get_rank()
+ device = torch.device(f"cuda:{rank}")
+
+ # Create and broadcast test data
+ if rank == 0:
+ full_logits = torch.randn(batch_size, seq_len, vocab_size, device=device)
+ target_ids = torch.randint(
+ 0, vocab_size, (batch_size, target_len), device=device
+ )
+ else:
+ full_logits = torch.empty(batch_size, seq_len, vocab_size, device=device)
+ target_ids = torch.empty(
+ batch_size, target_len, dtype=torch.int64, device=device
+ )
+
+ dist.broadcast(full_logits, src=0)
+ dist.broadcast(target_ids, src=0)
+
+ # Reference: non-sharded computation
+ expected = compute_logprobs(full_logits, target_ids, align=True)
+
+ # Create vocab-sharded DTensor
+ mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+ local_vocab = vocab_size // self.world_size
+ dtensor_logits = DTensor.from_local(
+ full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+ mesh,
+ placements=[Shard(2)],
+ )
+
+ # Compute with loss_parallel context
+ with loss_parallel():
+ result = compute_logprobs(dtensor_logits, target_ids, align=True)
+
+ # Verify output is Replicated as expected from loss_parallel
+ assert isinstance(result, DTensor)
+ assert result.placements[
+ 0
+ ].is_replicate(), f"Expected Replicated placement, got {result.placements}"
+ result = result.to_local()
+
+ torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
diff --git a/tests/unit_tests/util/test_selective_log_softmax.py b/tests/unit_tests/util/test_selective_log_softmax.py
deleted file mode 100644
index 4ca94f2c3..000000000
--- a/tests/unit_tests/util/test_selective_log_softmax.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import pytest
-import torch
-import torch.nn.functional as F
-from forge.util.ops import selective_log_softmax
-
-
-class TestSelectiveLogSoftmax:
- @pytest.mark.timeout(10)
- def test_basic_2d(self):
- """Test basic 2D case."""
- logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
- index = torch.tensor([0, 2]) # Select positions 0 and 2
- result = selective_log_softmax(logits, index)
- # Compare with torch's implementation
- expected = torch.gather(
- F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1)
- ).squeeze(-1)
- assert torch.allclose(result, expected, atol=1e-5)
- assert result.shape == (2,) # Same shape as index
-
- @pytest.mark.timeout(10)
- def test_single_row(self):
- """Test with single row."""
- logits = torch.tensor([[1.0, 2.0, 3.0]])
- index = torch.tensor([1]) # Select middle element
- result = selective_log_softmax(logits, index)
- # Manual calculation: log_softmax then select index 1
- log_probs = F.log_softmax(logits, dim=-1)
- expected = log_probs[0, 1]
- assert torch.allclose(result, expected)
- assert result.shape == (1,)
-
- @pytest.mark.timeout(10)
- def test_different_dtypes(self):
- """Test with different data types."""
- logits_f32 = torch.randn(2, 4, dtype=torch.float32)
- logits_bf16 = torch.randn(2, 4, dtype=torch.bfloat16)
- index = torch.tensor([0, 3])
- result_f32 = selective_log_softmax(logits_f32, index)
- result_bf16 = selective_log_softmax(logits_bf16, index)
- # Check output dtypes match input dtypes
- assert result_f32.dtype == torch.float32
- assert result_bf16.dtype == torch.bfloat16
- # Check shapes
- assert result_f32.shape == (2,)
- assert result_bf16.shape == (2,)
-
- @pytest.mark.timeout(10)
- def test_3d_tensor(self):
- """Test with 3D tensor."""
- batch, seq, vocab = 2, 3, 5
- logits = torch.randn(batch, seq, vocab)
- index = torch.randint(0, vocab, (batch, seq))
- result = selective_log_softmax(logits, index)
- # Should have same shape as index
- assert result.shape == (batch, seq)
- # All values should be negative (log probabilities)
- assert (result <= 0).all()
-
- @pytest.mark.timeout(10)
- def test_known_values(self):
- """Test with known values for manual verification."""
- # Simple case where we can calculate by hand
- logits = torch.tensor([[0.0, 0.0]]) # Equal logits
- index = torch.tensor([0])
- result = selective_log_softmax(logits, index)
- # log_softmax of [0, 0] gives [-log(2), -log(2)]
- # Selecting index 0 should give -log(2)
- expected = -torch.log(torch.tensor(2.0))
- assert torch.allclose(result, expected, atol=1e-6)
-
- @pytest.mark.timeout(10)
- def test_edge_cases(self):
- """Test edge cases."""
- # Test with single class
- logits = torch.tensor([[5.0]])
- index = torch.tensor([0])
- result = selective_log_softmax(logits, index)
- # log_softmax of single element is 0
- assert torch.allclose(result, torch.tensor([0.0]))
- # Test with large values (numerical stability)
- logits = torch.tensor([[100.0, 200.0]])
- index = torch.tensor([1])
- result = selective_log_softmax(logits, index)
- # Should not be NaN or inf
- assert torch.isfinite(result).all()
diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py
new file mode 100644
index 000000000..f922c3733
--- /dev/null
+++ b/tests/unit_tests/util/test_shared_tensor.py
@@ -0,0 +1,905 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pickle
+import time
+
+import pytest
+import torch
+
+# Assuming SharedTensor is in shared_tensor.py
+from forge.util._shared_tensor import SharedTensor
+from multiprocess import Process, Queue
+
+
+class TestSharedTensorCreation:
+ """Test tensor creation methods"""
+
+ def test_empty_creation(self):
+ """Test creating empty tensor"""
+ shape = (100, 200)
+ dtype = torch.float32
+
+ shared = SharedTensor.empty(shape, dtype)
+
+ assert shared.tensor.shape == torch.Size(shape)
+ assert shared.tensor.dtype == dtype
+ assert shared.tensor.shape == torch.Size(shape)
+ assert shared.tensor.dtype == dtype
+
+ shared.drop()
+
+ def test_empty_with_bfloat16(self):
+ """Test creating empty bfloat16 tensor"""
+ shape = (50, 50)
+ shared = SharedTensor.empty(shape, torch.bfloat16)
+
+ assert shared.tensor.dtype == torch.bfloat16
+ assert shared.tensor.dtype == torch.bfloat16
+
+ shared.drop()
+
+ def test_zeros_creation(self):
+ """Test creating zero-initialized tensor"""
+ shape = (10, 20)
+ shared = SharedTensor.zeros(shape, torch.float32)
+
+ tensor = shared.tensor
+ assert torch.all(tensor == 0)
+ assert tensor.sum().item() == 0.0
+
+ shared.drop()
+
+ def test_ones_creation(self):
+ """Test creating ones-initialized tensor"""
+ shape = (10, 20)
+ shared = SharedTensor.ones(shape, torch.float32)
+
+ tensor = shared.tensor
+ assert torch.all(tensor == 1)
+ assert tensor.sum().item() == 200.0
+
+ shared.drop()
+
+ def test_from_tensor_creation(self):
+ """Test creating from existing tensor"""
+ original = torch.randn(50, 50)
+ shared = SharedTensor(tensor=original)
+
+ assert shared.tensor.shape == original.shape
+ assert shared.tensor.dtype == original.dtype
+ assert torch.allclose(shared.tensor, original)
+
+ shared.drop()
+
+ def test_from_handle_creation(self):
+ """Test creating from handle"""
+ # Create original
+ original = SharedTensor.empty((10, 10), torch.float32)
+ original.tensor.fill_(5.0)
+
+ # Get handle
+ handle = original.get_handle()
+
+ # Create from handle
+ reconstructed = SharedTensor(handle=handle)
+
+ assert torch.all(reconstructed.tensor == 5.0)
+ assert reconstructed.tensor.shape == original.tensor.shape
+ assert reconstructed.tensor.dtype == original.tensor.dtype
+
+ original.drop()
+
+ def test_creation_requires_argument(self):
+ """Test that creation without arguments raises error"""
+ with pytest.raises(ValueError, match="Must provide either tensor or handle"):
+ SharedTensor()
+
+ @pytest.mark.parametrize(
+ "shape",
+ [
+ (10,),
+ (10, 20),
+ (5, 10, 15),
+ (2, 3, 4, 5),
+ ],
+ )
+ def test_various_shapes(self, shape):
+ """Test creation with various shapes"""
+ shared = SharedTensor.empty(shape, torch.float32)
+ assert shared.tensor.shape == torch.Size(shape)
+ assert shared.tensor.shape == torch.Size(shape)
+ shared.drop()
+
+
+class TestSharedTensorDtypes:
+ """Test all supported dtypes"""
+
+ @pytest.mark.parametrize(
+ "dtype",
+ [
+ torch.float32,
+ torch.float64,
+ torch.float16,
+ torch.bfloat16,
+ torch.int32,
+ torch.int64,
+ torch.int16,
+ torch.int8,
+ torch.uint8,
+ torch.bool,
+ ],
+ )
+ def test_all_dtypes(self, dtype):
+ """Test that all dtypes work correctly"""
+ shape = (10, 10)
+ shared = SharedTensor.empty(shape, dtype)
+
+ assert shared.tensor.dtype == dtype
+ assert shared.tensor.dtype == dtype
+
+ # Test that we can write to it
+ if dtype == torch.bool:
+ shared.tensor.fill_(True)
+ elif dtype in [torch.int32, torch.int64, torch.int16, torch.int8, torch.uint8]:
+ shared.tensor.fill_(42)
+ else:
+ shared.tensor.fill_(3.14)
+
+ shared.drop()
+
+ def test_dtype_conversion_in_handle(self):
+ """Test dtype is preserved through handle"""
+ for dtype in [torch.float32, torch.bfloat16, torch.int64]:
+ shared1 = SharedTensor.empty((5, 5), dtype)
+ handle = shared1.get_handle()
+
+ shared2 = SharedTensor(handle=handle)
+ assert shared2.tensor.dtype == dtype
+
+ shared1.drop()
+
+
+class TestSharedTensorOperations:
+ """Test tensor operations"""
+
+ def test_copy_from(self):
+ """Test copying data from another tensor"""
+ source = torch.randn(20, 30)
+ shared = SharedTensor.empty((20, 30), torch.float32)
+
+ shared.copy_from(source)
+
+ assert torch.allclose(shared.tensor, source)
+ shared.drop()
+
+ def test_copy_from_shape_mismatch(self):
+ """Test copy_from raises error on shape mismatch"""
+ source = torch.randn(10, 10)
+ shared = SharedTensor.empty((20, 20), torch.float32)
+
+ with pytest.raises(ValueError, match="Shape mismatch"):
+ shared.copy_from(source)
+
+ shared.drop()
+
+ def test_clone(self):
+ """Test cloning creates independent copy"""
+ original = SharedTensor.empty((10, 10), torch.float32)
+ original.tensor.fill_(5.0)
+
+ cloned = original.clone()
+
+ # Verify data is same
+ assert torch.all(cloned.tensor == 5.0)
+
+ # Verify they're independent
+ original.tensor.fill_(10.0)
+ assert torch.all(cloned.tensor == 5.0)
+ assert torch.all(original.tensor == 10.0)
+
+ original.drop()
+ cloned.drop()
+
+ def test_tensor_modifications(self):
+ """Test that modifications to tensor are reflected"""
+ shared = SharedTensor.zeros((10, 10), torch.float32)
+ tensor = shared.tensor
+
+ tensor[0, 0] = 99.0
+ tensor[5:, :] = 42.0
+
+ # Get tensor again and verify changes persist
+ tensor2 = shared.tensor
+ assert tensor2[0, 0].item() == 99.0
+ assert torch.all(tensor2[5:, :] == 42.0)
+
+ shared.drop()
+
+ def test_inplace_operations(self):
+ """Test in-place operations work"""
+ shared = SharedTensor.empty((100, 100), torch.float32)
+ tensor = shared.tensor
+
+ tensor.normal_(0, 1)
+ mean = tensor.mean().item()
+
+ tensor.add_(5.0)
+ new_mean = tensor.mean().item()
+
+ assert abs(new_mean - (mean + 5.0)) < 0.1
+
+ shared.drop()
+
+
+class TestSharedTensorSerialization:
+ """Test pickling and handle serialization"""
+
+ def test_handle_is_picklable(self):
+ """Test that handle can be pickled"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ handle = shared.get_handle()
+
+ # Pickle and unpickle
+ pickled = pickle.dumps(handle)
+ unpickled_handle = pickle.loads(pickled)
+
+ assert unpickled_handle == handle
+
+ shared.drop()
+
+ def test_handle_small_size(self):
+ """Test that handle is small (efficient for RPC)"""
+ shared = SharedTensor.empty((10000, 10000), torch.float32)
+ handle = shared.get_handle()
+
+ pickled = pickle.dumps(handle)
+
+ # Handle should be < 1KB even for huge tensors
+ assert len(pickled) < 1024
+
+ shared.drop()
+
+ def test_data_integrity_after_pickle(self):
+ """Test data is preserved through handle pickling"""
+ # Create and fill tensor
+ shared1 = SharedTensor.empty((50, 50), torch.bfloat16)
+ shared1.tensor.normal_(0, 1)
+ original_data = shared1.tensor.clone()
+
+ # Pickle handle
+ handle = shared1.get_handle()
+ pickled = pickle.dumps(handle)
+ unpickled_handle = pickle.loads(pickled)
+
+ # Reconstruct
+ shared2 = SharedTensor(handle=unpickled_handle)
+
+ # Verify data is same
+ assert torch.allclose(shared2.tensor.float(), original_data.float(), rtol=1e-3)
+
+ shared1.drop()
+
+
+class TestSharedTensorMemory:
+ """Test memory management and cleanup"""
+
+ def test_drop(self):
+ """Test drop removes shared memory"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ shm_name = shared._shm_name
+
+ # Verify shared memory exists
+ tensor = shared.tensor
+ tensor.fill_(5.0)
+
+ # Drop shared memory
+ shared.drop()
+
+ # Trying to attach should fail
+ from multiprocessing import shared_memory
+
+ with pytest.raises(FileNotFoundError):
+ shared_memory.SharedMemory(name=shm_name)
+
+ def test_multiple_views_same_memory(self):
+ """Test multiple tensor views point to same memory"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+
+ tensor1 = shared.tensor
+ tensor1.fill_(5.0)
+
+ tensor2 = shared.tensor
+ assert torch.all(tensor2 == 5.0)
+
+ # Modify through tensor2
+ tensor2.fill_(10.0)
+
+ # Verify tensor1 sees the change
+ assert torch.all(tensor1 == 10.0)
+
+ shared.drop()
+
+ def test_handle_reconstruction_shares_memory(self):
+ """Test that handle reconstruction shares same memory"""
+ shared1 = SharedTensor.empty((20, 20), torch.float32)
+ shared1.tensor.fill_(7.0)
+
+ handle = shared1.get_handle()
+ shared2 = SharedTensor(handle=handle)
+
+ # Modify through shared2
+ shared2.tensor.fill_(14.0)
+
+ # Verify shared1 sees the change
+ assert torch.all(shared1.tensor == 14.0)
+
+ shared1.drop()
+
+
+class TestSharedTensorEdgeCases:
+ """Test edge cases and error conditions"""
+
+ def test_empty_shape(self):
+ """Test scalar tensor (empty shape)"""
+ shared = SharedTensor.ones((), torch.float32)
+ assert shared.tensor.shape == ()
+ assert shared.tensor.numel() == 1
+ assert torch.allclose(
+ shared.tensor,
+ torch.ones(
+ (),
+ ),
+ )
+ shared.drop()
+
+ def test_single_element_tensor(self):
+ """Test 1-element tensor"""
+ shared = SharedTensor.empty((1,), torch.float32)
+ shared.tensor.fill_(42.0)
+ assert shared.tensor.item() == 42.0
+ shared.drop()
+
+ def test_large_tensor(self):
+ """Test large tensor (1GB)"""
+ # 1GB tensor: 250M float32 elements
+ shape = (250_000_000,)
+ shared = SharedTensor.empty(shape, torch.float32)
+
+ assert shared.tensor.shape == shape
+ assert shared.tensor.numel() == 250_000_000
+
+ shared.drop()
+
+ def test_non_contiguous_tensor_conversion(self):
+ """Test that non-contiguous tensors are handled"""
+ # Create non-contiguous tensor
+ original = torch.randn(10, 10).t() # Transpose makes it non-contiguous
+ assert not original.is_contiguous()
+
+ # Should work (internally makes contiguous)
+ shared = SharedTensor(tensor=original)
+
+ # Result should match
+ assert torch.allclose(shared.tensor, original)
+
+ shared.drop()
+
+ def test_repr(self):
+ """Test string representation"""
+ shared = SharedTensor.empty((10, 20), torch.float32)
+ repr_str = repr(shared)
+
+ assert "SharedTensor" in repr_str
+ assert "10, 20" in repr_str
+ assert "float32" in repr_str
+ assert shared._shm_name in repr_str
+
+ shared.drop()
+
+
+class TestSharedTensorMultiprocess:
+ """Test multiprocess scenarios"""
+
+ def test_multiprocess_read(self):
+ """Test reading shared tensor from another process"""
+
+ def reader_process(handle_dict, result_queue):
+ with SharedTensor(handle=handle_dict) as shared:
+ result_queue.put(shared.tensor.sum().item())
+
+ # Create shared tensor in main process
+ shared = SharedTensor.empty((100, 100), torch.float32)
+ shared.tensor.fill_(5.0)
+
+ # Read from child process
+ result_queue = Queue()
+ handle = shared.get_handle()
+
+ p = Process(target=reader_process, args=(handle, result_queue))
+ p.start()
+ p.join()
+
+ result = result_queue.get()
+ expected = 5.0 * 100 * 100
+
+ assert abs(result - expected) < 1e-5
+
+ shared.drop()
+
+ def test_multiprocess_write(self):
+ """Test writing to shared tensor from another process"""
+
+ def writer_process(handle_dict, value):
+ with SharedTensor(handle=handle_dict) as shared:
+ shared.tensor.fill_(value)
+
+ # Create empty shared tensor
+ shared = SharedTensor.empty((50, 50), torch.float32)
+ shared.tensor.zero_()
+
+ # Write from child process
+ handle = shared.get_handle()
+
+ p = Process(target=writer_process, args=(handle, 42.0))
+ p.start()
+ p.join()
+
+ # Verify in main process
+ assert torch.all(shared.tensor == 42.0)
+
+ shared.drop()
+
+ def test_multiprocess_bidirectional(self):
+ """Test bidirectional communication"""
+
+ def worker_process(input_handle, output_handle):
+ with SharedTensor(handle=input_handle) as input_shared:
+ with SharedTensor(handle=output_handle) as output_shared:
+ # Compute: output = input * 2
+ output_shared.tensor.copy_(input_shared.tensor * 2)
+
+ # Create input and output tensors
+ input_shared = SharedTensor.empty((100, 100), torch.float32)
+ input_shared.tensor.normal_(0, 1)
+ input_data = input_shared.tensor.clone()
+
+ output_shared = SharedTensor.empty((100, 100), torch.float32)
+
+ # Process in child
+ p = Process(
+ target=worker_process,
+ args=(input_shared.get_handle(), output_shared.get_handle()),
+ )
+ p.start()
+ p.join()
+
+ # Verify result
+ expected = input_data * 2
+ assert torch.allclose(
+ output_shared.tensor, expected
+ ), "output: {}, expected: {}".format(output_shared.tensor, expected)
+
+ input_shared.drop()
+ output_shared.drop()
+
+
+class TestSharedTensorPerformance:
+ """Performance-related tests"""
+
+ def test_empty_faster_than_from_tensor(self):
+ """Test that empty() is faster than from tensor"""
+ shape = (1000, 1000)
+
+ # Time empty creation
+ start = time.time()
+ for _ in range(10):
+ shared = SharedTensor.empty(shape, torch.float32)
+ shared.drop()
+ empty_time = time.time() - start
+
+ # Time from_tensor creation
+ start = time.time()
+ for _ in range(10):
+ tensor = torch.randn(shape)
+ shared = SharedTensor(tensor=tensor)
+ shared.drop()
+ from_tensor_time = time.time() - start
+
+ # empty() should be faster (no data copying)
+ assert empty_time < from_tensor_time
+
+ def test_handle_serialization_fast(self):
+ """Test that handle serialization is fast"""
+ shared = SharedTensor.empty((10000, 10000), torch.float32)
+ handle = shared.get_handle()
+
+ start = time.time()
+ for _ in range(1000):
+ pickled = pickle.dumps(handle)
+ unpickled = pickle.loads(pickled)
+ elapsed = time.time() - start
+
+ # Should be able to do 1000 round trips in < 0.1 seconds
+ assert elapsed < 0.1
+
+ shared.drop()
+
+
+class TestSharedTensorHandleToSharedTensor:
+ """Test SharedTensorHandle.to_shared_tensor() method"""
+
+ def test_to_shared_tensor_basic(self):
+ """Test basic creation of SharedTensor from handle using to_shared_tensor method"""
+ original = SharedTensor.empty((10, 10), torch.float32)
+ original.tensor.fill_(7.0)
+
+ handle = original.get_handle()
+ reconstructed = handle.to_shared_tensor()
+
+ assert torch.all(reconstructed.tensor == 7.0)
+ assert reconstructed.tensor.shape == original.tensor.shape
+ assert reconstructed.tensor.dtype == original.tensor.dtype
+
+ original.drop()
+
+ def test_to_shared_tensor_preserves_data(self):
+ """Test that to_shared_tensor preserves original data"""
+ original = SharedTensor.empty((20, 30), torch.float32)
+ original.tensor.normal_(0, 1)
+ original_data = original.tensor.clone()
+
+ handle = original.get_handle()
+ reconstructed = handle.to_shared_tensor()
+
+ assert torch.allclose(reconstructed.tensor, original_data)
+
+ original.drop()
+
+ def test_to_shared_tensor_shares_memory(self):
+ """Test that to_shared_tensor shares memory with original"""
+ original = SharedTensor.empty((15, 15), torch.float32)
+ original.tensor.fill_(5.0)
+
+ handle = original.get_handle()
+ reconstructed = handle.to_shared_tensor()
+
+ reconstructed.tensor.fill_(10.0)
+
+ assert torch.all(original.tensor == 10.0)
+
+ original.drop()
+
+ def test_to_shared_tensor_with_various_dtypes(self):
+ """Test to_shared_tensor works with different data types"""
+ for dtype in [torch.float32, torch.float64, torch.bfloat16, torch.int32]:
+ original = SharedTensor.empty((5, 5), dtype)
+ if (
+ dtype == torch.bfloat16
+ or dtype == torch.float32
+ or dtype == torch.float64
+ ):
+ original.tensor.normal_(0, 1)
+ else:
+ original.tensor.fill_(42)
+
+ handle = original.get_handle()
+ reconstructed = handle.to_shared_tensor()
+
+ assert reconstructed.tensor.dtype == dtype
+ if dtype == torch.bfloat16:
+ assert torch.allclose(
+ reconstructed.tensor.float(), original.tensor.float(), rtol=1e-3
+ )
+ else:
+ assert torch.allclose(reconstructed.tensor, original.tensor)
+
+ original.drop()
+
+ def test_to_shared_tensor_multiprocess(self):
+ """Test to_shared_tensor in multiprocess scenario"""
+
+ def worker_process(handle, result_queue):
+ with handle.to_shared_tensor() as shared:
+ result_queue.put(shared.tensor.sum().item())
+
+ original = SharedTensor.empty((50, 50), torch.float32)
+ original.tensor.fill_(3.0)
+
+ handle = original.get_handle()
+ result_queue = Queue()
+
+ p = Process(target=worker_process, args=(handle, result_queue))
+ p.start()
+ p.join()
+
+ result = result_queue.get()
+ expected = 3.0 * 50 * 50
+
+ assert abs(result - expected) < 1e-5
+
+ original.drop()
+
+ def test_to_shared_tensor_equivalent_to_constructor(self):
+ """Test that handle.to_shared_tensor() is equivalent to SharedTensor(handle=handle)"""
+ original = SharedTensor.empty((25, 25), torch.float32)
+ original.tensor.normal_(0, 1)
+
+ handle = original.get_handle()
+
+ via_method = handle.to_shared_tensor()
+ via_constructor = SharedTensor(handle=handle)
+
+ assert torch.allclose(via_method.tensor, via_constructor.tensor)
+ assert via_method.tensor.shape == via_constructor.tensor.shape
+ assert via_method.tensor.dtype == via_constructor.tensor.dtype
+
+ original.drop()
+
+
+class TestSharedTensorBfloat16:
+ """Specific tests for bfloat16 support"""
+
+ def test_bfloat16_creation(self):
+ """Test bfloat16 tensor creation"""
+ shared = SharedTensor.empty((100, 100), torch.bfloat16)
+ assert shared.tensor.dtype == torch.bfloat16
+ shared.drop()
+
+ def test_bfloat16_from_tensor(self):
+ """Test creating shared tensor from bfloat16 tensor"""
+ original = torch.randn(50, 50, dtype=torch.bfloat16)
+ shared = SharedTensor(tensor=original)
+
+ assert shared.tensor.dtype == torch.bfloat16
+ assert torch.allclose(shared.tensor.float(), original.float(), rtol=1e-3)
+
+ shared.drop()
+
+ def test_bfloat16_handle_preservation(self):
+ """Test bfloat16 dtype preserved through handle"""
+ shared1 = SharedTensor.empty((20, 20), torch.bfloat16)
+ shared1.tensor.normal_(0, 1)
+
+ handle = shared1.get_handle()
+ shared2 = SharedTensor(handle=handle)
+
+ assert shared2.tensor.dtype == torch.bfloat16
+ assert torch.allclose(shared1.tensor.float(), shared2.tensor.float(), rtol=1e-3)
+
+ shared1.drop()
+
+ def test_bfloat16_operations(self):
+ """Test operations on bfloat16 tensors"""
+ shared = SharedTensor.empty((100, 100), torch.bfloat16)
+ tensor = shared.tensor
+
+ tensor.normal_(0, 1)
+ mean = tensor.float().mean().item()
+
+ # Mean should be close to 0
+ assert abs(mean) < 0.1
+
+ shared.drop()
+
+
+class TestSharedTensorCloseAndCleanup:
+ """Test explicit close() and cleanup patterns to prevent memory leaks"""
+
+ def test_close_method(self):
+ """Test explicit close() releases handle and sets closed state"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ shared.tensor.fill_(5.0)
+
+ assert not shared.is_closed
+
+ # Close should not raise
+ shared.close()
+
+ assert shared.is_closed
+
+ # Cleanup
+ shared._shm.unlink()
+
+ def test_tensor_access_after_close_raises_error(self):
+ """Test that accessing tensor after close raises RuntimeError"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ shared.tensor.fill_(5.0)
+
+ shared.close()
+
+ with pytest.raises(RuntimeError, match="Cannot access tensor after close"):
+ _ = shared.tensor
+
+ # Cleanup
+ shared._shm.unlink()
+
+ def test_get_handle_after_close_raises_error(self):
+ """Test that getting handle after close raises RuntimeError"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+
+ shared.close()
+
+ with pytest.raises(RuntimeError, match="Cannot get handle after close"):
+ shared.get_handle()
+
+ # Cleanup
+ shared._shm.unlink()
+
+ def test_is_closed_property(self):
+ """Test is_closed property reflects state correctly"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+
+ assert not shared.is_closed
+
+ shared.close()
+
+ assert shared.is_closed
+
+ # Cleanup
+ shared._shm.unlink()
+
+ def test_cached_tensor_reference_becomes_invalid_after_close(self):
+ """Test that tensor reference obtained before close becomes invalid after close"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ shared.tensor.fill_(5.0)
+
+ # Get reference before close
+ tensor_ref = shared.tensor
+
+ shared.close()
+
+ # After close(), the memory mapping is unmapped, so even cached references
+ # point to invalid memory. Accessing them will cause segfault or undefined behavior.
+ # We can't safely test this, but we document it.
+
+ # Accessing via shared.tensor raises error (this is what we CAN test)
+ with pytest.raises(RuntimeError):
+ _ = shared.tensor
+
+ # Cleanup
+ shared._shm.unlink()
+
+ def test_context_manager(self):
+ """Test context manager automatically closes"""
+ shm_name = None
+
+ with SharedTensor.empty((10, 10), torch.float32) as shared:
+ shm_name = shared._shm_name
+ shared.tensor.fill_(7.0)
+ assert torch.all(shared.tensor == 7.0)
+
+ # After exiting context, should be closed (but not unlinked yet)
+ # We need to unlink separately
+ from multiprocessing import shared_memory
+
+ # Should still be able to attach (not unlinked)
+ shm = shared_memory.SharedMemory(name=shm_name)
+ shm.close()
+ shm.unlink()
+
+ def test_creator_receiver_workflow(self):
+ """Test proper workflow: creator creates, gets handle, closes, receiver uses and closes"""
+
+ def receiver_process(handle, result_queue):
+ # Receiver creates SharedTensor from handle
+ with SharedTensor(handle=handle) as shared:
+ result = shared.tensor.sum().item()
+ result_queue.put(result)
+ # Context manager auto-closes
+
+ # Creator process
+ shared = SharedTensor.empty((50, 50), torch.float32)
+ shared.tensor.fill_(4.0)
+ handle = shared.get_handle()
+ shared.close() # Creator closes its reference
+
+ # Pass to receiver
+ result_queue = Queue()
+ p = Process(target=receiver_process, args=(handle, result_queue))
+ p.start()
+ p.join()
+
+ result = result_queue.get()
+ assert abs(result - (4.0 * 50 * 50)) < 1e-5
+
+ # Unlink after all processes done
+ handle.drop()
+
+ def test_handle_drop_without_creating_shared_tensor(self):
+ """Test that handle.drop() doesn't create unnecessary SharedTensor instance"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ shared.tensor.fill_(3.0)
+ handle = shared.get_handle()
+ shared.close()
+
+ # drop() should work without creating new SharedTensor
+ handle.drop()
+
+ # Memory should be unlinked
+ from multiprocessing import shared_memory
+
+ with pytest.raises(FileNotFoundError):
+ shared_memory.SharedMemory(name=handle.shm_name)
+
+ def test_multiple_receivers_close_independently(self):
+ """Test that multiple receivers can close independently"""
+
+ def receiver_process(handle, value, result_queue):
+ with SharedTensor(handle=handle) as shared:
+ result = shared.tensor[0, 0].item() == value
+ result_queue.put(result)
+
+ # Creator
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ shared.tensor.fill_(9.0)
+ handle = shared.get_handle()
+ shared.close()
+
+ # Multiple receivers
+ result_queue = Queue()
+ processes = []
+ for _ in range(3):
+ p = Process(target=receiver_process, args=(handle, 9.0, result_queue))
+ p.start()
+ processes.append(p)
+
+ for p in processes:
+ p.join()
+
+ # All should succeed
+ for _ in range(3):
+ assert result_queue.get() is True
+
+ # Cleanup
+ handle.drop()
+
+ def test_close_is_idempotent(self):
+ """Test that calling close() multiple times is safe"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+
+ # Multiple closes should not raise
+ shared.close()
+ shared.close()
+ shared.close()
+
+ # Cleanup
+ shared.drop()
+
+ def test_drop_is_idempotent(self):
+ """Test that calling drop() multiple times is safe"""
+ shared = SharedTensor.empty((10, 10), torch.float32)
+ handle = shared.get_handle()
+ shared.close()
+
+ # Multiple drops should not raise
+ handle.drop()
+ handle.drop()
+ handle.drop()
+
+ def test_proper_cleanup_prevents_leak(self):
+ """Test that proper close + unlink pattern doesn't leak"""
+ import glob
+
+ # Get initial shared memory count
+ shm_before = len(glob.glob("/dev/shm/shared_tensor_*"))
+
+ # Create and properly cleanup 10 shared tensors
+ for _ in range(10):
+ shared = SharedTensor.empty((100, 100), torch.float32)
+ handle = shared.get_handle()
+ shared.close()
+ handle.drop()
+
+ # Check no leaks
+ shm_after = len(glob.glob("/dev/shm/shared_tensor_*"))
+ assert (
+ shm_after == shm_before
+ ), f"Memory leak detected: {shm_after - shm_before} tensors leaked"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "--tb=short"])