CI wheel builder failure on Linux for Python 3.11, 3.12 and 3.13 with BrokenProcessPool #32599

@lesteve

Description

Last successful run on October 22: build log

First failing run on October 23: build log

The root cause is unclear for now: the dependency versions look identical between the two runs. The Python 3.11 build does use a different Python bugfix version, but I think this is a red herring, since the Python 3.12 and 3.13 builds have the same bugfix version in both runs and still fail.
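
The remote traceback below points at the unpickling of the pandas Categorical column inside the loky worker. As a first check, here is a minimal sketch (hypothetical, not taken from the CI logs) that isolates the pandas side with a plain pickle round-trip in a single process:

    import pickle

    import numpy as np
    import pandas as pd

    # A Categorical with float categories, like the binned column the
    # failing test appends to the dataframe.
    cat = pd.Categorical(np.array([1.0, 0.0, 2.0, 1.0, 2.0, 0.0]))

    # loky ships task arguments through pickle, so this exercises the same
    # Categorical.__setstate__ path that raises in the worker.
    restored = pickle.loads(pickle.dumps(cat))
    print(restored.dtype)

If this round-trip succeeds locally, the breakage is more likely tied to the specific numpy/pandas builds installed in the wheel-testing venv than to pandas pickling in general.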

Failing tests:


  FAILED inspection/tests/test_permutation_importance.py::test_permutation_importance_equivalence_array_dataframe[0.5-2]
  FAILED inspection/tests/test_permutation_importance.py::test_permutation_importance_equivalence_array_dataframe[1.0-2]
________ test_permutation_importance_equivalence_array_dataframe[0.5-2] ________
  joblib.externals.loky.process_executor._RemoteTraceback: 
  """
  Traceback (most recent call last):
    File "/tmp/tmp.q3sHOyewIE/venv/lib/python3.13/site-packages/joblib/externals/loky/process_executor.py", line 453, in _process_worker
      call_item = call_queue.get(block=True, timeout=timeout)
    File "/opt/_internal/cpython-3.13.8/lib/python3.13/multiprocessing/queues.py", line 120, in get
      return _ForkingPickler.loads(res)
             ~~~~~~~~~~~~~~~~~~~~~^^^^^
    File "/tmp/tmp.q3sHOyewIE/venv/lib/python3.13/site-packages/pandas/core/arrays/categorical.py", line 1726, in __setstate__
      return super().__setstate__(state)
             ~~~~~~~~~~~~~~~~~~~~^^^^^^^
    File "pandas/_libs/arrays.pyx", line 85, in pandas._libs.arrays.NDArrayBacked.__setstate__
    File "pandas/_libs/arrays.pyx", line 103, in pandas._libs.arrays.NDArrayBacked.__setstate__
  NotImplementedError: (CategoricalDtype(categories=[0.0, 1.0, 2.0], ordered=False, categories_dtype=float64), array([1, 0, 2, 1, 0, 2, 2, 1, 0, 1, 2, 1, 2, 0, 0, 1, 1, 0, 1, 2, 2, 0,
         1, 2, 2, 0, 2, 1, 2, 0, 0, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 1, 0,
         1, 0, 1, 1, 1, 2, 1, 0, 1, 0, 2, 0, 1, 1, 2, 1, 1, 1, 0, 0, 0, 2,
         0, 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2, 2, 1, 0, 0, 0, 1, 2, 0, 2,
         1, 0, 0, 1, 1, 2, 2, 2, 1, 2, 2, 2], dtype=int8))
  """
  
  The above exception was the direct cause of the following exception:
  
  n_jobs = 2, max_samples = 0.5
  
      @pytest.mark.parametrize("n_jobs", [None, 1, 2])
      @pytest.mark.parametrize("max_samples", [0.5, 1.0])
      def test_permutation_importance_equivalence_array_dataframe(n_jobs, max_samples):
          # This test checks that the column shuffling logic has the same behavior
          # both a dataframe and a simple numpy array.
          pd = pytest.importorskip("pandas")
      
          # regression test to make sure that sequential and parallel calls will
          # output the same results.
          X, y = make_regression(n_samples=100, n_features=5, random_state=0)
          X_df = pd.DataFrame(X)
      
          # Add a categorical feature that is statistically linked to y:
          binner = KBinsDiscretizer(
              n_bins=3,
              encode="ordinal",
              quantile_method="averaged_inverted_cdf",
          )
          cat_column = binner.fit_transform(y.reshape(-1, 1))
      
          # Concatenate the extra column to the numpy array: integers will be
          # cast to float values
          X = np.hstack([X, cat_column])
          assert X.dtype.kind == "f"
      
          # Insert extra column as a non-numpy-native dtype:
          cat_column = pd.Categorical(cat_column.ravel())
          new_col_idx = len(X_df.columns)
          X_df[new_col_idx] = cat_column
          assert X_df[new_col_idx].dtype == cat_column.dtype
      
          # Stich an arbitrary index to the dataframe:
          X_df.index = np.arange(len(X_df)).astype(str)
      
          rf = RandomForestRegressor(n_estimators=5, max_depth=3, random_state=0)
          rf.fit(X, y)
      
          n_repeats = 3
          importance_array = permutation_importance(
              rf,
              X,
              y,
              n_repeats=n_repeats,
              random_state=0,
              n_jobs=n_jobs,
              max_samples=max_samples,
          )
      
          # First check that the problem is structured enough and that the model is
          # complex enough to not yield trivial, constant importances:
          imp_min = importance_array["importances"].min()
          imp_max = importance_array["importances"].max()
          assert imp_max - imp_min > 0.3
      
          # Now check that importances computed on dataframe matche the values
          # of those computed on the array with the same data.
  >       importance_dataframe = permutation_importance(
              rf,
              X_df,
              y,
              n_repeats=n_repeats,
              random_state=0,
              n_jobs=n_jobs,
              max_samples=max_samples,
          )
  
  ../venv/lib/python3.13/site-packages/sklearn/inspection/tests/test_permutation_importance.py:357: 
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  ../venv/lib/python3.13/site-packages/sklearn/utils/_param_validation.py:218: in wrapper
      return func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  ../venv/lib/python3.13/site-packages/sklearn/inspection/_permutation_importance.py:288: in permutation_importance
      scores = Parallel(n_jobs=n_jobs)(
  ../venv/lib/python3.13/site-packages/sklearn/utils/parallel.py:91: in __call__
      return super().__call__(iterable_with_config_and_warning_filters)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ../venv/lib/python3.13/site-packages/joblib/parallel.py:2072: in __call__
      return output if self.return_generator else list(output)
                                                  ^^^^^^^^^^^^
  ../venv/lib/python3.13/site-packages/joblib/parallel.py:1682: in _get_outputs
      yield from self._retrieve()
  ../venv/lib/python3.13/site-packages/joblib/parallel.py:1784: in _retrieve
      self._raise_error_fast()
  ../venv/lib/python3.13/site-packages/joblib/parallel.py:1859: in _raise_error_fast
      error_job.get_result(self.timeout)
  ../venv/lib/python3.13/site-packages/joblib/parallel.py:758: in get_result
      return self._return_or_raise()
             ^^^^^^^^^^^^^^^^^^^^^^^
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  
  self = <joblib.parallel.BatchCompletionCallBack object at 0x7f8c44e4d6d0>
  
      def _return_or_raise(self):
          try:
              if self.status == TASK_ERROR:
  >               raise self._result
  E               joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
  
  ../venv/lib/python3.13/site-packages/joblib/parallel.py:773: BrokenProcessPool
  ________ test_permutation_importance_equivalence_array_dataframe[1.0-2] ________
  (same traceback as the [0.5-2] failure above, with max_samples = 1.0 instead of 0.5)
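
For reference, a standalone way to exercise the same code path outside scikit-learn (a sketch under the assumption that the crash happens when the loky workers unpickle the dataframe; whether it actually fails depends on the installed numpy/pandas combination):

    import numpy as np
    import pandas as pd
    from joblib import Parallel, delayed

    X_df = pd.DataFrame({"num": np.arange(100, dtype=float)})
    # Categorical column with float categories, mirroring the failing test.
    X_df["cat"] = pd.Categorical(np.tile([0.0, 1.0, 2.0], 34)[:100])

    # With the default loky backend, each worker unpickles X_df, which is
    # where the CI jobs hit NotImplementedError / BrokenProcessPool.
    print(Parallel(n_jobs=2)(delayed(len)(X_df) for _ in range(4)))

Running the same snippet with n_jobs=1 (sequential, so no pickling of arguments) should help confirm that the failure is confined to the inter-process serialization layer.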
