-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
Fix make sure enabling array_api_dispatch=True does not break any estimator on NumPy inputs
#32846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
40c7c82
5f865e1
30a67a9
bd5b9fb
7527de4
e16f1f3
bb4702b
474f49f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| - Fixed a bug that would cause Cython-based estimators to fail when fit on | ||
| NumPy inputs when setting `sklearn.set_config(array_api_dispatch=True)`. By | ||
| :user:`Olivier Grisel <ogrisel>`. | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -23,6 +23,11 @@ | |||||
| __all__ = ["xpx"] # we import xpx here just to re-export it, need this to appease ruff | ||||||
|
|
||||||
| _NUMPY_NAMESPACE_NAMES = {"numpy", "sklearn.externals.array_api_compat.numpy"} | ||||||
| REMOVE_TYPES_DEFAULT = ( | ||||||
| str, | ||||||
| list, | ||||||
| tuple, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def yield_namespaces(include_numpy_namespaces=True): | ||||||
|
|
@@ -167,7 +172,7 @@ def _single_array_device(array): | |||||
| return array.device | ||||||
|
|
||||||
|
|
||||||
| def device(*array_list, remove_none=True, remove_types=(str,)): | ||||||
| def device(*array_list, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT): | ||||||
| """Hardware device where the array data resides on. | ||||||
|
|
||||||
| If the hardware device is not the same for all arrays, an error is raised. | ||||||
|
|
@@ -180,7 +185,7 @@ def device(*array_list, remove_none=True, remove_types=(str,)): | |||||
| remove_none : bool, default=True | ||||||
| Whether to ignore None objects passed in array_list. | ||||||
|
|
||||||
| remove_types : tuple or list, default=(str,) | ||||||
| remove_types : tuple or list, default=(str, list, tuple) | ||||||
| Types to ignore in array_list. | ||||||
|
|
||||||
| Returns | ||||||
|
|
@@ -290,7 +295,7 @@ def supported_float_dtypes(xp, device=None): | |||||
| return tuple(valid_float_dtypes) | ||||||
|
|
||||||
|
|
||||||
| def _remove_non_arrays(*arrays, remove_none=True, remove_types=(str,)): | ||||||
| def _remove_non_arrays(*arrays, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT): | ||||||
| """Filter arrays to exclude None and/or specific types. | ||||||
|
|
||||||
| Sparse arrays are always filtered out. | ||||||
|
|
@@ -303,7 +308,7 @@ def _remove_non_arrays(*arrays, remove_none=True, remove_types=(str,)): | |||||
| remove_none : bool, default=True | ||||||
| Whether to ignore None objects passed in arrays. | ||||||
|
|
||||||
| remove_types : tuple or list, default=(str,) | ||||||
| remove_types : tuple or list, (str, list, tuple) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| Types to ignore in the arrays. | ||||||
|
|
||||||
| Returns | ||||||
|
|
@@ -328,7 +333,27 @@ def _remove_non_arrays(*arrays, remove_none=True, remove_types=(str,)): | |||||
| return filtered_arrays | ||||||
|
|
||||||
|
|
||||||
| def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): | ||||||
| def _unwrap_memoryviewslices(*arrays): | ||||||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| # Since _cyutility._memoryviewslice is an implementation detail of the | ||||||
| # Cython runtime, we would rather not introduce a possibly brittle | ||||||
| # import statement to run `isinstance`-based filtering, hence the | ||||||
| # attribute-based type inspection. | ||||||
| unwrapped = [] | ||||||
| for a in arrays: | ||||||
| a_type = type(a) | ||||||
| if ( | ||||||
| a_type.__module__ == "_cyutility" | ||||||
| and a_type.__name__ == "_memoryviewslice" | ||||||
| and hasattr(a, "base") | ||||||
| ): | ||||||
| a = a.base | ||||||
| unwrapped.append(a) | ||||||
| return unwrapped | ||||||
|
|
||||||
|
|
||||||
| def get_namespace( | ||||||
| *arrays, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT, xp=None | ||||||
| ): | ||||||
| """Get namespace of arrays. | ||||||
|
|
||||||
| Introspect `arrays` arguments and return their common Array API compatible | ||||||
|
|
@@ -364,7 +389,7 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): | |||||
| remove_none : bool, default=True | ||||||
| Whether to ignore None objects passed in arrays. | ||||||
|
|
||||||
| remove_types : tuple or list, default=(str,) | ||||||
| remove_types : tuple or list, (str, list, tuple) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| Types to ignore in the arrays. | ||||||
|
|
||||||
| xp : module, default=None | ||||||
|
|
@@ -399,12 +424,19 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): | |||||
| remove_types=remove_types, | ||||||
| ) | ||||||
|
|
||||||
| # get_namespace can be called by helper functions that are used both in | ||||||
| # array API compatible code and non-array API Cython related code. To | ||||||
| # support the latter on NumPy inputs without raising a TypeError, we | ||||||
| # unwrap potential Cython memoryview slices here. | ||||||
| arrays = _unwrap_memoryviewslices(*arrays) | ||||||
|
|
||||||
| if not arrays: | ||||||
| return np_compat, False | ||||||
|
|
||||||
| _check_array_api_dispatch(array_api_dispatch) | ||||||
|
|
||||||
| namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True | ||||||
| namespace = array_api_compat.get_namespace(*arrays) | ||||||
| is_array_api_compliant = True | ||||||
|
|
||||||
| if namespace.__name__ == "array_api_strict" and hasattr( | ||||||
| namespace, "set_array_api_strict_flags" | ||||||
|
|
@@ -415,7 +447,7 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): | |||||
|
|
||||||
|
|
||||||
| def get_namespace_and_device( | ||||||
| *array_list, remove_none=True, remove_types=(str,), xp=None | ||||||
| *array_list, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT, xp=None | ||||||
| ): | ||||||
| """Combination into one single function of `get_namespace` and `device`. | ||||||
|
|
||||||
|
|
@@ -425,7 +457,7 @@ def get_namespace_and_device( | |||||
| Array objects. | ||||||
| remove_none : bool, default=True | ||||||
| Whether to ignore None objects passed in arrays. | ||||||
| remove_types : tuple or list, default=(str,) | ||||||
| remove_types : tuple or list, (str, list, tuple) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| Types to ignore in the arrays. | ||||||
| xp : module, default=None | ||||||
| Precomputed array namespace module. When passed, typically from a caller | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -196,9 +196,11 @@ def _yield_checks(estimator): | |||||||
| yield check_estimators_pickle | ||||||||
| yield partial(check_estimators_pickle, readonly_memmap=True) | ||||||||
|
|
||||||||
| if tags.array_api_support: | ||||||||
| for check in _yield_array_api_checks(estimator): | ||||||||
| yield check | ||||||||
| for check in _yield_array_api_checks( | ||||||||
| estimator, | ||||||||
| only_numpy=not tags.array_api_support, | ||||||||
| ): | ||||||||
| yield check | ||||||||
|
|
||||||||
| yield check_f_contiguous_array_estimator | ||||||||
|
|
||||||||
|
|
@@ -336,18 +338,30 @@ def _yield_outliers_checks(estimator): | |||||||
| yield check_non_transformer_estimators_n_iter | ||||||||
|
|
||||||||
|
|
||||||||
| def _yield_array_api_checks(estimator): | ||||||||
| for ( | ||||||||
| array_namespace, | ||||||||
| device, | ||||||||
| dtype_name, | ||||||||
| ) in yield_namespace_device_dtype_combinations(): | ||||||||
| def _yield_array_api_checks(estimator, only_numpy=False): | ||||||||
| if only_numpy: | ||||||||
| # Enabling array API dispatch and feeding the NumPy inputs should yield | ||||||||
| # consistent results, even if estimator does not explicitly support | ||||||||
| # array API. | ||||||||
| yield partial( | ||||||||
| check_array_api_input, | ||||||||
| array_namespace=array_namespace, | ||||||||
| dtype_name=dtype_name, | ||||||||
| device=device, | ||||||||
| array_namespace="numpy", | ||||||||
| expect_only_array_outputs=False, | ||||||||
| ) | ||||||||
| else: | ||||||||
| # These extended checks should pass for all estimators that declare | ||||||||
| # array API support in their tags. | ||||||||
| for ( | ||||||||
| array_namespace, | ||||||||
| device, | ||||||||
| dtype_name, | ||||||||
| ) in yield_namespace_device_dtype_combinations(): | ||||||||
| yield partial( | ||||||||
| check_array_api_input, | ||||||||
| array_namespace=array_namespace, | ||||||||
| dtype_name=dtype_name, | ||||||||
| device=device, | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
||||||||
| def _yield_all_checks(estimator, legacy: bool): | ||||||||
|
|
@@ -1048,6 +1062,7 @@ def check_array_api_input( | |||||||
| dtype_name="float64", | ||||||||
| check_values=False, | ||||||||
| check_sample_weight=False, | ||||||||
| expect_only_array_outputs=True, | ||||||||
| ): | ||||||||
| """Check that the estimator can work consistently with the Array API | ||||||||
|
|
||||||||
|
|
@@ -1057,17 +1072,24 @@ def check_array_api_input( | |||||||
| When check_values is True, it also checks that calling the estimator on the | ||||||||
| array_api Array gives the same results as ndarrays. | ||||||||
|
|
||||||||
| When sample_weight is True, dummy sample weights are passed to the fit call. | ||||||||
| When check_sample_weight is True, dummy sample weights are passed to the | ||||||||
| fit call. | ||||||||
|
|
||||||||
| When expect_only_array_outputs is False, the check accepts non-array | ||||||||
| outputs from estimator methods (e.g., sparse data structures). This is | ||||||||
|
Comment on lines
+1078
to
+1079
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you expand on this? I was confused just reading this, but https://github.com/scikit-learn/scikit-learn/pull/32846/changes#r2602981393 cleared it up for me. |
||||||||
| useful to test that enabling array API dispatch does not break the | ||||||||
| estimator, even if the estimator does not support array API. | ||||||||
| """ | ||||||||
| xp = _array_api_for_tests(array_namespace, device) | ||||||||
|
|
||||||||
| X, y = make_classification(random_state=42) | ||||||||
| X, y = make_classification(n_samples=30, n_features=10, random_state=42) | ||||||||
| X = X.astype(dtype_name, copy=False) | ||||||||
|
|
||||||||
| X = _enforce_estimator_tags_X(estimator_orig, X) | ||||||||
| y = _enforce_estimator_tags_y(estimator_orig, y) | ||||||||
|
|
||||||||
| est = clone(estimator_orig) | ||||||||
| set_random_state(est) | ||||||||
|
|
||||||||
| X_xp = xp.asarray(X, device=device) | ||||||||
| y_xp = xp.asarray(y, device=device) | ||||||||
|
|
@@ -1193,47 +1215,47 @@ def check_array_api_input( | |||||||
| f"got {result_ns}." | ||||||||
| ) | ||||||||
|
|
||||||||
| with config_context(array_api_dispatch=True): | ||||||||
| assert array_device(result_xp) == array_device(X_xp) | ||||||||
|
|
||||||||
| result_xp_np = _convert_to_numpy(result_xp, xp=xp) | ||||||||
| if expect_only_array_outputs: | ||||||||
| with config_context(array_api_dispatch=True): | ||||||||
| assert array_device(result_xp) == array_device(X_xp) | ||||||||
|
|
||||||||
| if check_values: | ||||||||
| assert_allclose( | ||||||||
| result, | ||||||||
| result_xp_np, | ||||||||
| err_msg=f"{method} did not the return the same result", | ||||||||
| atol=_atol_for_type(X.dtype), | ||||||||
| ) | ||||||||
| else: | ||||||||
| if hasattr(result, "shape"): | ||||||||
| result_xp_np = _convert_to_numpy(result_xp, xp=xp) | ||||||||
| if check_values: | ||||||||
| assert_allclose( | ||||||||
| result, | ||||||||
| result_xp_np, | ||||||||
| err_msg=f"{method} did not the return the same result", | ||||||||
| atol=_atol_for_type(X.dtype), | ||||||||
| ) | ||||||||
| elif hasattr(result, "shape"): | ||||||||
| assert result.shape == result_xp_np.shape | ||||||||
| assert result.dtype == result_xp_np.dtype | ||||||||
|
|
||||||||
| if method_name == "transform" and hasattr(est, "inverse_transform"): | ||||||||
| inverse_result = est.inverse_transform(result) | ||||||||
| with config_context(array_api_dispatch=True): | ||||||||
| invese_result_xp = est_xp.inverse_transform(result_xp) | ||||||||
| inverse_result_ns = get_namespace(invese_result_xp)[0].__name__ | ||||||||
| assert inverse_result_ns == input_ns, ( | ||||||||
| "'inverse_transform' output is in wrong namespace, expected" | ||||||||
| f" {input_ns}, got {inverse_result_ns}." | ||||||||
| ) | ||||||||
|
|
||||||||
| with config_context(array_api_dispatch=True): | ||||||||
| assert array_device(invese_result_xp) == array_device(X_xp) | ||||||||
|
|
||||||||
| invese_result_xp_np = _convert_to_numpy(invese_result_xp, xp=xp) | ||||||||
| if check_values: | ||||||||
| assert_allclose( | ||||||||
| inverse_result, | ||||||||
| invese_result_xp_np, | ||||||||
| err_msg="inverse_transform did not the return the same result", | ||||||||
| atol=_atol_for_type(X.dtype), | ||||||||
| inverse_result_xp = est_xp.inverse_transform(result_xp) | ||||||||
|
|
||||||||
| if expect_only_array_outputs: | ||||||||
| with config_context(array_api_dispatch=True): | ||||||||
| inverse_result_ns = get_namespace(inverse_result_xp)[0].__name__ | ||||||||
| assert inverse_result_ns == input_ns, ( | ||||||||
| "'inverse_transform' output is in wrong namespace, expected" | ||||||||
| f" {input_ns}, got {inverse_result_ns}." | ||||||||
| ) | ||||||||
| else: | ||||||||
| assert inverse_result.shape == invese_result_xp_np.shape | ||||||||
| assert inverse_result.dtype == invese_result_xp_np.dtype | ||||||||
| assert array_device(inverse_result_xp) == array_device(X_xp) | ||||||||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you need a Maybe we should do something about it because it's easy to fall into this trap. One potential idea was to add an
Suggested change
|
||||||||
|
|
||||||||
| inverse_result_xp_np = _convert_to_numpy(inverse_result_xp, xp=xp) | ||||||||
| if check_values: | ||||||||
| assert_allclose( | ||||||||
| inverse_result, | ||||||||
| inverse_result_xp_np, | ||||||||
| err_msg="inverse_transform did not the return the same result", | ||||||||
| atol=_atol_for_type(X.dtype), | ||||||||
| ) | ||||||||
| elif hasattr(result, "shape"): | ||||||||
| assert inverse_result.shape == inverse_result_xp_np.shape | ||||||||
| assert inverse_result.dtype == inverse_result_xp_np.dtype | ||||||||
|
|
||||||||
|
|
||||||||
| def check_array_api_input_and_values( | ||||||||
|
|
||||||||
Uh oh!
There was an error while loading. Please reload this page.