3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/32846.fix.rst
@@ -0,0 +1,3 @@
- Fixed a bug that caused Cython-based estimators to fail when fit on NumPy
inputs with `sklearn.set_config(array_api_dispatch=True)` enabled. By
:user:`Olivier Grisel <ogrisel>`.
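
A minimal reproduction sketch of the failure mode this entry describes. The estimator choice and data shapes are illustrative assumptions, not taken from the PR:

```python
import numpy as np

import sklearn
from sklearn.cluster import KMeans

# Hypothetical trigger: an estimator whose Cython code paths hand memoryview
# slices to scikit-learn's array API helpers while dispatch is enabled.
sklearn.set_config(array_api_dispatch=True)
X = np.random.default_rng(0).normal(size=(20, 3))
KMeans(n_clusters=2, n_init=2).fit(X)  # could raise TypeError before the fix
```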
2 changes: 1 addition & 1 deletion sklearn/decomposition/_dict_learning.py
@@ -1360,7 +1360,7 @@ def fit(self, X, y=None):
if X.shape[1] != self.dictionary.shape[1]:
raise ValueError(
"Dictionary and X have different numbers of features:"
f"dictionary.shape: {self.dictionary.shape} X.shape{X.shape}"
f"dictionary.shape: {self.dictionary.shape} X.shape: {X.shape}"
)
return self

50 changes: 41 additions & 9 deletions sklearn/utils/_array_api.py
@@ -23,6 +23,11 @@
__all__ = ["xpx"] # we import xpx here just to re-export it, need this to appease ruff

_NUMPY_NAMESPACE_NAMES = {"numpy", "sklearn.externals.array_api_compat.numpy"}
REMOVE_TYPES_DEFAULT = (
str,
list,
tuple,
)
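
An illustrative consequence of the widened default (a sketch; it assumes array API dispatch is left disabled, so the wrapped NumPy namespace is returned):

```python
import numpy as np

from sklearn.utils._array_api import get_namespace

# None and plain Python containers (str, list, tuple) are filtered out
# before namespace inspection; only the NumPy array is considered.
xp, is_array_api = get_namespace(np.ones(3), None, "euclidean", [0, 1], (2, 3))
```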


def yield_namespaces(include_numpy_namespaces=True):
@@ -167,7 +172,7 @@ def _single_array_device(array):
return array.device


def device(*array_list, remove_none=True, remove_types=(str,)):
def device(*array_list, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT):
"""Hardware device where the array data resides on.

If the hardware device is not the same for all arrays, an error is raised.
@@ -180,7 +185,7 @@ def device(*array_list, remove_none=True, remove_types=(str,)):
remove_none : bool, default=True
Whether to ignore None objects passed in array_list.

remove_types : tuple or list, default=(str,)
remove_types : tuple or list, default=(str, list, tuple)
Types to ignore in array_list.

Returns
@@ -290,7 +295,7 @@ def supported_float_dtypes(xp, device=None):
return tuple(valid_float_dtypes)


def _remove_non_arrays(*arrays, remove_none=True, remove_types=(str,)):
def _remove_non_arrays(*arrays, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT):
"""Filter arrays to exclude None and/or specific types.

Sparse arrays are always filtered out.
@@ -303,7 +308,7 @@ def _remove_non_arrays(*arrays, remove_none=True, remove_types=(str,)):
remove_none : bool, default=True
Whether to ignore None objects passed in arrays.

remove_types : tuple or list, default=(str,)
remove_types : tuple or list, (str, list, tuple)
[Review comment from a Member] Suggested change:
-    remove_types : tuple or list, (str, list, tuple)
+    remove_types : tuple or list, default=(str, list, tuple)

Types to ignore in the arrays.

Returns
@@ -328,7 +333,27 @@ def _remove_non_arrays(*arrays, remove_none=True, remove_types=(str,)):
return filtered_arrays


def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
def _unwrap_memoryviewslices(*arrays):
# Since _cyutility._memoryviewslice is an implementation detail of the
# Cython runtime, we would rather not introduce a possibly brittle
# import statement to run `isinstance`-based filtering, hence the
# attribute-based type inspection.
unwrapped = []
for a in arrays:
a_type = type(a)
if (
a_type.__module__ == "_cyutility"
and a_type.__name__ == "_memoryviewslice"
and hasattr(a, "base")
):
a = a.base
unwrapped.append(a)
return unwrapped
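
A small sketch of the attribute-based detection above, using a hypothetical stand-in class, since real `_cyutility._memoryviewslice` objects only exist inside compiled Cython extensions:

```python
import numpy as np

from sklearn.utils._array_api import _unwrap_memoryviewslices

# Stand-in spoofing the module/name/base attributes inspected above;
# for illustration only.
FakeSlice = type("_memoryviewslice", (), {"__module__": "_cyutility"})
fake = FakeSlice()
fake.base = np.arange(6).reshape(2, 3)  # the wrapped NumPy array

unwrapped = _unwrap_memoryviewslices(fake, np.ones(3), "kept-as-is")
# -> [fake.base, np.ones(3), "kept-as-is"]: only the fake slice is unwrapped.
```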


def get_namespace(
*arrays, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT, xp=None
):
"""Get namespace of arrays.

Introspect `arrays` arguments and return their common Array API compatible
@@ -364,7 +389,7 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
remove_none : bool, default=True
Whether to ignore None objects passed in arrays.

remove_types : tuple or list, default=(str,)
remove_types : tuple or list, (str, list, tuple)
[Review comment from a Member] Suggested change:
-    remove_types : tuple or list, (str, list, tuple)
+    remove_types : tuple or list, default=(str, list, tuple)

Types to ignore in the arrays.

xp : module, default=None
@@ -399,12 +424,19 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
remove_types=remove_types,
)

# get_namespace can be called by helper functions that are used both in
# array API compatible code and non-array API Cython related code. To
# support the latter on NumPy inputs without raising a TypeError, we
# unwrap potential Cython memoryview slices here.
arrays = _unwrap_memoryviewslices(*arrays)

if not arrays:
return np_compat, False

_check_array_api_dispatch(array_api_dispatch)

namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True
namespace = array_api_compat.get_namespace(*arrays)
is_array_api_compliant = True

if namespace.__name__ == "array_api_strict" and hasattr(
namespace, "set_array_api_strict_flags"
@@ -415,7 +447,7 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):


def get_namespace_and_device(
*array_list, remove_none=True, remove_types=(str,), xp=None
*array_list, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT, xp=None
):
"""Combination into one single function of `get_namespace` and `device`.

@@ -425,7 +457,7 @@ def get_namespace_and_device(
Array objects.
remove_none : bool, default=True
Whether to ignore None objects passed in arrays.
remove_types : tuple or list, default=(str,)
remove_types : tuple or list, (str, list, tuple)
[Review comment from a Member] Suggested change:
-    remove_types : tuple or list, (str, list, tuple)
+    remove_types : tuple or list, default=(str, list, tuple)

Types to ignore in the arrays.
xp : module, default=None
Precomputed array namespace module. When passed, typically from a caller
24 changes: 20 additions & 4 deletions sklearn/utils/_test_common/instance_generator.py
@@ -46,7 +46,10 @@
SparsePCA,
TruncatedSVD,
)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
QuadraticDiscriminantAnalysis,
)
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import (
AdaBoostClassifier,
@@ -79,6 +82,7 @@
SequentialFeatureSelector,
)
from sklearn.frozen import FrozenEstimator
from sklearn.impute import SimpleImputer
from sklearn.kernel_approximation import (
Nystroem,
PolynomialCountSketch,
@@ -559,11 +563,16 @@
dict(solver="lbfgs"),
],
},
GaussianMixture: {"check_dict_unchanged": dict(max_iter=5, n_init=2)},
GaussianMixture: {
"check_dict_unchanged": dict(max_iter=5, n_init=2),
"check_array_api_input": dict(
max_iter=5, n_init=2, init_params="random_from_data"
),
},
GaussianRandomProjection: {"check_dict_unchanged": dict(n_components=1)},
GraphicalLasso: {"check_array_api_input": dict(max_iter=5, alpha=1.0)},
IncrementalPCA: {"check_dict_unchanged": dict(batch_size=10, n_components=1)},
Isomap: {"check_dict_unchanged": dict(n_components=1)},
KMeans: {"check_dict_unchanged": dict(max_iter=5, n_clusters=1, n_init=2)},
# TODO(1.9) simplify when averaged_inverted_cdf is the default
KBinsDiscretizer: {
"check_sample_weight_equivalence_on_dense_data": [
@@ -595,7 +604,11 @@
strategy="quantile", quantile_method="averaged_inverted_cdf"
),
},
KernelPCA: {"check_dict_unchanged": dict(n_components=1)},
KernelPCA: {
"check_dict_unchanged": dict(n_components=1),
"check_array_api_input": dict(fit_inverse_transform=True),
},
KMeans: {"check_dict_unchanged": dict(max_iter=5, n_clusters=1, n_init=2)},
LassoLars: {"check_non_transformer_estimators_n_iter": dict(alpha=0.0)},
LatentDirichletAllocation: {
"check_dict_unchanged": dict(batch_size=10, max_iter=5, n_components=1)
@@ -693,6 +706,7 @@
dict(solver="highs-ipm"),
],
},
QuadraticDiscriminantAnalysis: {"check_array_api_input": dict(reg_param=1.0)},
RBFSampler: {"check_dict_unchanged": dict(n_components=1)},
Ridge: {
"check_sample_weight_equivalence_on_dense_data": [
@@ -720,7 +734,9 @@
],
},
SkewedChi2Sampler: {"check_dict_unchanged": dict(n_components=1)},
SimpleImputer: {"check_array_api_input": dict(add_indicator=True)},
SparseCoder: {
"check_array_api_input": dict(dictionary=rng.normal(size=(5, 10))),
"check_estimators_dtypes": dict(dictionary=rng.normal(size=(5, 5))),
"check_dtype_object": dict(dictionary=rng.normal(size=(5, 10))),
"check_transformers_unfitted_stateless": dict(
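
For context, a sketch of how entries like the ones above are consumed. The mapping name `PER_ESTIMATOR_CHECK_PARAMS` is an assumption about the dict shown in this diff:

```python
from sklearn.mixture import GaussianMixture
from sklearn.utils._test_common.instance_generator import PER_ESTIMATOR_CHECK_PARAMS

# Each estimator maps check names to the constructor kwargs that the check
# should instantiate it with.
params = PER_ESTIMATOR_CHECK_PARAMS[GaussianMixture]["check_array_api_input"]
est = GaussianMixture(**params)  # max_iter=5, n_init=2, init_params="random_from_data"
```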
116 changes: 69 additions & 47 deletions sklearn/utils/estimator_checks.py
@@ -196,9 +196,11 @@ def _yield_checks(estimator):
yield check_estimators_pickle
yield partial(check_estimators_pickle, readonly_memmap=True)

if tags.array_api_support:
for check in _yield_array_api_checks(estimator):
yield check
for check in _yield_array_api_checks(
estimator,
only_numpy=not tags.array_api_support,
):
yield check

yield check_f_contiguous_array_estimator

@@ -336,18 +338,30 @@ def _yield_outliers_checks(estimator):
yield check_non_transformer_estimators_n_iter


def _yield_array_api_checks(estimator):
for (
array_namespace,
device,
dtype_name,
) in yield_namespace_device_dtype_combinations():
def _yield_array_api_checks(estimator, only_numpy=False):
if only_numpy:
# Enabling array API dispatch and feeding NumPy inputs should yield
# consistent results, even if the estimator does not explicitly support
# array API.
yield partial(
check_array_api_input,
array_namespace=array_namespace,
dtype_name=dtype_name,
device=device,
array_namespace="numpy",
expect_only_array_outputs=False,
)
else:
# These extended checks should pass for all estimators that declare
# array API support in their tags.
for (
array_namespace,
device,
dtype_name,
) in yield_namespace_device_dtype_combinations():
yield partial(
check_array_api_input,
array_namespace=array_namespace,
dtype_name=dtype_name,
device=device,
)
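
For illustration, the `only_numpy` branch boils down to a single partial; invoking it directly looks like the sketch below (the estimator choice is arbitrary):

```python
from functools import partial

from sklearn.impute import SimpleImputer
from sklearn.utils.estimator_checks import check_array_api_input

numpy_only_check = partial(
    check_array_api_input,
    array_namespace="numpy",
    expect_only_array_outputs=False,
)
# Fits on NumPy inputs with array API dispatch enabled and checks that the
# results stay consistent with the dispatch-disabled run.
numpy_only_check("SimpleImputer", SimpleImputer(add_indicator=True))
```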


def _yield_all_checks(estimator, legacy: bool):
@@ -1048,6 +1062,7 @@ def check_array_api_input(
dtype_name="float64",
check_values=False,
check_sample_weight=False,
expect_only_array_outputs=True,
):
"""Check that the estimator can work consistently with the Array API

@@ -1057,17 +1072,24 @@
When check_values is True, it also checks that calling the estimator on the
array_api Array gives the same results as ndarrays.

When sample_weight is True, dummy sample weights are passed to the fit call.
When check_sample_weight is True, dummy sample weights are passed to the
fit call.

When expect_only_array_outputs is False, the check accepts non-array
outputs from estimator methods (e.g., sparse data structures). This is
useful to test that enabling array API dispatch does not break the
estimator, even if the estimator does not support array API.
"""

[Review comment from a Member, on lines +1078 to +1079]
Could you expand on this? I was confused just reading this, but
https://github.com/scikit-learn/scikit-learn/pull/32846/changes#r2602981393 cleared it up for me.
xp = _array_api_for_tests(array_namespace, device)

X, y = make_classification(random_state=42)
X, y = make_classification(n_samples=30, n_features=10, random_state=42)
X = X.astype(dtype_name, copy=False)

X = _enforce_estimator_tags_X(estimator_orig, X)
y = _enforce_estimator_tags_y(estimator_orig, y)

est = clone(estimator_orig)
set_random_state(est)

X_xp = xp.asarray(X, device=device)
y_xp = xp.asarray(y, device=device)
@@ -1193,47 +1215,47 @@ def check_array_api_input(
f"got {result_ns}."
)

with config_context(array_api_dispatch=True):
assert array_device(result_xp) == array_device(X_xp)

result_xp_np = _convert_to_numpy(result_xp, xp=xp)
if expect_only_array_outputs:
with config_context(array_api_dispatch=True):
assert array_device(result_xp) == array_device(X_xp)

if check_values:
assert_allclose(
result,
result_xp_np,
err_msg=f"{method} did not the return the same result",
atol=_atol_for_type(X.dtype),
)
else:
if hasattr(result, "shape"):
result_xp_np = _convert_to_numpy(result_xp, xp=xp)
if check_values:
assert_allclose(
result,
result_xp_np,
err_msg=f"{method} did not the return the same result",
atol=_atol_for_type(X.dtype),
)
elif hasattr(result, "shape"):
assert result.shape == result_xp_np.shape
assert result.dtype == result_xp_np.dtype

if method_name == "transform" and hasattr(est, "inverse_transform"):
inverse_result = est.inverse_transform(result)
with config_context(array_api_dispatch=True):
invese_result_xp = est_xp.inverse_transform(result_xp)
inverse_result_ns = get_namespace(invese_result_xp)[0].__name__
assert inverse_result_ns == input_ns, (
"'inverse_transform' output is in wrong namespace, expected"
f" {input_ns}, got {inverse_result_ns}."
)

with config_context(array_api_dispatch=True):
assert array_device(invese_result_xp) == array_device(X_xp)

invese_result_xp_np = _convert_to_numpy(invese_result_xp, xp=xp)
if check_values:
assert_allclose(
inverse_result,
invese_result_xp_np,
err_msg="inverse_transform did not the return the same result",
atol=_atol_for_type(X.dtype),
inverse_result_xp = est_xp.inverse_transform(result_xp)

if expect_only_array_outputs:
with config_context(array_api_dispatch=True):
inverse_result_ns = get_namespace(inverse_result_xp)[0].__name__
assert inverse_result_ns == input_ns, (
"'inverse_transform' output is in wrong namespace, expected"
f" {input_ns}, got {inverse_result_ns}."
)
else:
assert inverse_result.shape == invese_result_xp_np.shape
assert inverse_result.dtype == invese_result_xp_np.dtype
assert array_device(inverse_result_xp) == array_device(X_xp)
[Review comment from @lesteve, Dec 15, 2025]
I think you need a config_context here otherwise this assert statement will
always pass, checking that None == None. This was noted some time ago by
Stefanie in #31814.

Maybe we should do something about it because it's easy to fall into this
trap. One potential idea was to add an assert_same_device that uses
array_api_compat.device and does not depend on
config_context(array_api_dispatch=True) (and a similar assert function like
assert_same_namespace). Also maybe it would be nice to have a better name
than device for our config_context-dependent device function since I find
the clash with array_api_compat.device a tiny bit confusing.

Suggested change:
-    assert array_device(inverse_result_xp) == array_device(X_xp)
+    with config_context(array_api_dispatch=True):
+        assert array_device(result_xp) == array_device(X_xp)
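
One possible shape for the assert_same_device helper floated above (a sketch only; no such helper exists in scikit-learn). It leans on array_api_compat.device, which does not depend on config_context:

```python
from sklearn.externals import array_api_compat


def assert_same_device(a, b):
    # array_api_compat.device reads the device directly from each array,
    # so the comparison is meaningful even when array_api_dispatch is off.
    dev_a, dev_b = array_api_compat.device(a), array_api_compat.device(b)
    assert dev_a == dev_b, f"devices differ: {dev_a} != {dev_b}"
```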


inverse_result_xp_np = _convert_to_numpy(inverse_result_xp, xp=xp)
if check_values:
assert_allclose(
inverse_result,
inverse_result_xp_np,
err_msg="inverse_transform did not the return the same result",
atol=_atol_for_type(X.dtype),
)
elif hasattr(inverse_result, "shape"):
assert inverse_result.shape == inverse_result_xp_np.shape
assert inverse_result.dtype == inverse_result_xp_np.dtype


def check_array_api_input_and_values(
Expand Down