diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 8577cce6f..05967cac8 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -1,3 +1,5 @@ +import os + from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.interval import IntervalTrigger @@ -33,6 +35,10 @@ process_terminating_jobs, ) from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + _scheduler = AsyncIOScheduler() @@ -74,69 +80,93 @@ def start_background_tasks() -> AsyncIOScheduler: # that the first waiting for the lock will acquire it. # The jitter is needed to give all tasks a chance to acquire locks. - _scheduler.add_job(process_probes, IntervalTrigger(seconds=3, jitter=1)) - _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) - _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1) + if os.getenv("DSTACK_PROBES_PROCESSING_DISABLED") is None: + _scheduler.add_job(process_probes, IntervalTrigger(seconds=3, jitter=1)) + if os.getenv("DSTACK_COLLECT_METRICS_PROCESSING_DISABLED") is None: + _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) + if os.getenv("DSTACK_DELETE_METRICS_PROCESSING_DISABLED") is None: + _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) + if os.getenv("DSTACK_DELETE_EVENTS_PROCESSING_DISABLED") is None: + _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1) if settings.ENABLE_PROMETHEUS_METRICS: + if os.getenv("DSTACK_COLLECT_PROMETHEUS_METRICS_PROCESSING_DISABLED") is None: + _scheduler.add_job( + collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 + 
) + if os.getenv("DSTACK_DELETE_PROMETHEUS_METRICS_PROCESSING_DISABLED") is None: + _scheduler.add_job( + delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1 + ) + if os.getenv("DSTACK_GATEWAY_PROCESSING_DISABLED") is None: + _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) _scheduler.add_job( - collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 + process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) - _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) - _scheduler.add_job( - process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 - ) - _scheduler.add_job( - process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 - ) - _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) - _scheduler.add_job( - process_fleets, - IntervalTrigger(seconds=10, jitter=2), - max_instances=1, - ) - _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) - for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): - # Add multiple copies of tasks if requested. - # max_instances=1 for additional copies to avoid running too many tasks. - # Move other tasks here when they need per-replica scaling. 
+ if os.getenv("DSTACK_SUBMITTED_VOLUMES_PROCESSING_DISABLED") is None: _scheduler.add_job( - process_submitted_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=4 if replica == 0 else 1, + process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) + if os.getenv("DSTACK_IDLE_VOLUMES_PROCESSING_DISABLED") is None: _scheduler.add_job( - process_running_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, + process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 ) + if os.getenv("DSTACK_PLACEMENT_GROUPS_PROCESSING_DISABLED") is None: + _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) + if os.getenv("DSTACK_FLEETS_PROCESSING_DISABLED") is None: _scheduler.add_job( - process_terminating_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, + process_fleets, + IntervalTrigger(seconds=10, jitter=2), + max_instances=1, ) + if os.getenv("DSTACK_DELETE_INSTANCE_HEALTH_CHECKS_PROCESSING_DISABLED") is None: _scheduler.add_job( - process_runs, - IntervalTrigger(seconds=2, jitter=1), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_instances, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_compute_groups, - IntervalTrigger(seconds=15, jitter=2), - kwargs={"batch_size": 1}, - max_instances=2 if replica == 0 else 1, + delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1 ) + for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): + # Add multiple copies of tasks if requested. + # max_instances=1 for additional copies to avoid running too many tasks. + # Move other tasks here when they need per-replica scaling. 
+ if os.getenv("DSTACK_SUBMITTED_JOBS_PROCESSING_DISABLED") is None: + _scheduler.add_job( + process_submitted_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=4 if replica == 0 else 1, + ) + if os.getenv("DSTACK_RUNNING_JOBS_PROCESSING_DISABLED") is None: + _scheduler.add_job( + process_running_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + if os.getenv("DSTACK_TERMINATING_JOBS_PROCESSING_DISABLED") is None: + _scheduler.add_job( + process_terminating_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + if os.getenv("DSTACK_RUNS_PROCESSING_DISABLED") is None: + _scheduler.add_job( + process_runs, + IntervalTrigger(seconds=2, jitter=1), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + if os.getenv("DSTACK_INSTANCES_PROCESSING_DISABLED") is None: + _scheduler.add_job( + process_instances, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + if os.getenv("DSTACK_COMPUTE_GROUPS_PROCESSING_DISABLED") is None: + _scheduler.add_job( + process_compute_groups, + IntervalTrigger(seconds=15, jitter=2), + kwargs={"batch_size": 1}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py index a54cb9e31..3db5a2824 100644 --- a/src/dstack/_internal/server/background/tasks/process_gateways.py +++ b/src/dstack/_internal/server/background/tasks/process_gateways.py @@ -24,6 +24,7 @@ logger = get_logger(__name__) +@sentry_utils.instrument_background_task async def process_gateways_connections(): await _remove_inactive_connections() await _process_active_connections() diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py 
b/src/dstack/_internal/server/background/tasks/process_instances.py index 9a14bdc30..454d6ee18 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -11,7 +11,7 @@ from pydantic import ValidationError from sqlalchemy import and_, delete, func, not_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, with_loader_criteria +from sqlalchemy.orm import joinedload from dstack._internal import settings from dstack._internal.core.backends.base.compute import ( @@ -218,9 +218,8 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel): .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) .options( - joinedload(InstanceModel.fleet).joinedload(FleetModel.instances), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True + joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) ), ) .execution_options(populate_existing=True) @@ -233,9 +232,8 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel): .options(joinedload(InstanceModel.project)) .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) .options( - joinedload(InstanceModel.fleet).joinedload(FleetModel.instances), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True + joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) ), ) .execution_options(populate_existing=True) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 95ae519d0..cbae143c3 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -1,3 +1,4 @@ 
+import os import uuid from collections.abc import Callable from datetime import datetime @@ -42,7 +43,12 @@ ) from dstack._internal.core.models.projects import Project from dstack._internal.core.models.resources import ResourcesSpec -from dstack._internal.core.models.runs import JobProvisioningData, Requirements, get_policy_map +from dstack._internal.core.models.runs import ( + JobProvisioningData, + Requirements, + RunStatus, + get_policy_map, +) from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models @@ -53,6 +59,7 @@ JobModel, MemberModel, ProjectModel, + RunModel, UserModel, ) from dstack._internal.server.services import events @@ -113,6 +120,10 @@ async def list_projects_with_no_active_fleets( Applies to all users (both regular users and admins require membership). """ + # Testing https://github.com/sqlalchemy/sqlalchemy/discussions/12536 + if os.getenv("DSTACK_LIST_NO_ACTIVE_FLEETS_DISABLED") is not None: + return [] + active_fleet_alias = aliased(FleetModel) member_alias = aliased(MemberModel) @@ -613,48 +624,61 @@ async def delete_fleets( instance_nums: Optional[List[int]] = None, ): res = await session.execute( - select(FleetModel) + select(FleetModel.id) .where( FleetModel.project_id == project.id, FleetModel.name.in_(names), FleetModel.deleted == False, ) - .options(joinedload(FleetModel.instances)) + .order_by(FleetModel.id) # take locks in order + .with_for_update(key_share=True) ) - fleet_models = res.scalars().unique().all() - fleets_ids = sorted([f.id for f in fleet_models]) - instances_ids = sorted([i.id for f in fleet_models for i in f.instances]) - await session.commit() - logger.info("Deleting fleets: %s", [v.name for v in fleet_models]) + fleets_ids = list(res.scalars().unique().all()) + res = await session.execute( + select(InstanceModel.id) + .where( + 
InstanceModel.fleet_id.in_(fleets_ids), + InstanceModel.deleted == False, + ) + .order_by(InstanceModel.id) # take locks in order + .with_for_update(key_share=True) + ) + instances_ids = list(res.scalars().unique().all()) + if is_db_sqlite(): + # Start new transaction to see committed changes after lock + await session.commit() async with ( get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, fleets_ids), get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids), ): - # Refetch after lock - # TODO: Lock instances with FOR UPDATE? - # TODO: Do not lock fleet when deleting only instances + # Refetch after lock. + # TODO: Do not lock fleet when deleting only instances. res = await session.execute( select(FleetModel) - .where( - FleetModel.project_id == project.id, - FleetModel.name.in_(names), - FleetModel.deleted == False, - ) + .where(FleetModel.id.in_(fleets_ids)) .options( - selectinload(FleetModel.instances) + joinedload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) .joinedload(InstanceModel.jobs) .load_only(JobModel.id) ) - .options(selectinload(FleetModel.runs)) + .options( + joinedload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ) + ) .execution_options(populate_existing=True) - .order_by(FleetModel.id) # take locks in order - .with_for_update(key_share=True) ) fleet_models = res.scalars().unique().all() fleets = [fleet_model_to_fleet(m) for m in fleet_models] for fleet in fleets: if fleet.spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) + if instance_nums is None: + logger.info("Deleting fleets: %s", [f.name for f in fleet_models]) + else: + logger.info( + "Deleting fleets %s instances %s", [f.name for f in fleet_models], instance_nums + ) for fleet_model in fleet_models: _terminate_fleet_instances(fleet_model=fleet_model, instance_nums=instance_nums) # TERMINATING fleets are deleted by process_fleets after 
instances are terminated diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py index a72dc0c16..38bffc442 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -597,6 +597,34 @@ async def test_terminate(self, test_db, session: AsyncSession): assert instance.deleted_at is not None assert instance.finished_at is not None + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_terminates_terminating_deleted_instance(self, test_db, session: AsyncSession): + # There was a race condition when an instance could stay in Terminating while marked as deleted. + # TODO: Drop this after all such "bad" instances are processed. + project = await create_project(session=session) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.TERMINATING + ) + instance.deleted = True + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + instance.last_job_processed_at = instance.deleted_at = ( + get_current_datetime() + dt.timedelta(minutes=-19) + ) + await session.commit() + + with self.mock_terminate_in_backend() as mock: + await process_instances() + mock.assert_called_once() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.TERMINATED + assert instance.deleted == True + assert instance.deleted_at is not None + assert instance.finished_at is not None + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize(