-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
Add an integration test for the use of array API in a non-trivial pipeline #32873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| ( | ||
| RidgeClassifier(), | ||
| 3, | ||
| ["accuracy"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would like to add "roc_auc" or "average_precision" here, but they do not yet support array API at this time: #26024.
| # Ensure that all inputs are on the y_pred device and namespace. | ||
| y_true = move_to(y_true, xp=xp, device=device) | ||
| sample_weight = move_to(sample_weight, xp=xp, device=device) | ||
| if hasattr(multioutput, "shape"): | ||
| multioutput = move_to(multioutput, xp=xp, device=device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just leave this to _check_reg_targets to handle?
I think this kind of test would have help identify some of the problem discovered during the release process earlier.
Note that it still fails because
RidgeandLinearDiscriminantAnalysisstill do not accept mixed-array inputs (see #28668).Not sure if we should fix that here or in a PR that introduces a common test for that particular problem.
EDIT: I started to push a fix for the regression case to investigate all the possibly cascading bugs to fix (hiding one after the other). For the regression case we need both #28668 and #32755 and this is enough.
For classification, it seems that there is extra work needed in the private scorer code on top of updating the classifier and the metrics functions themselves.