Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- :class:`inspection.DecisionBoundaryDisplay` now displays distinct colors for all classes when using ``plot_method="contour"`` with ``response_method="predict"``, and when using ``plot_method="contourf"`` with ``response_method="predict_proba"`` or ``"decision_function"``. By :user:`Levente Csibi <leweex95>`. :pr:`32867`
18 changes: 18 additions & 0 deletions sklearn/inspection/_plot/decision_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,21 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar

plot_func = getattr(ax, plot_method)
if self.response.ndim == 2:
# For discrete responses (e.g., from predict), ensure all levels are used
# to display distinct colors for each class
if plot_method in ("contour", "contourf") and np.issubdtype(
self.response.dtype, np.integer
):
unique_levels = np.unique(self.response)
if plot_method == "contourf":
levels = np.concatenate(
[unique_levels - 0.5, [unique_levels.max() + 0.5]]
)
else:
levels = unique_levels

if "levels" not in kwargs:
kwargs["levels"] = levels
self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
else: # self.response.ndim == 3
n_responses = self.response.shape[-1]
Expand Down Expand Up @@ -260,6 +275,9 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
if plot_method == "contour":
# Plot only argmax map for contour
class_map = self.response.argmax(axis=2)
# Ensure levels match the number of classes for distinct colors
if "levels" not in kwargs:
kwargs["levels"] = np.unique(class_map)
self.surface_ = plot_func(
self.xx0, self.xx1, class_map, colors=colors, **kwargs
)
Expand Down
45 changes: 45 additions & 0 deletions sklearn/inspection/_plot/tests/test_boundary_decision_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,48 @@ class SubclassOfDisplay(DecisionBoundaryDisplay):
curve = SubclassOfDisplay.from_estimator(estimator=clf, X=X)

assert isinstance(curve, SubclassOfDisplay)


@pytest.mark.parametrize("plot_method", ["contourf", "contour"])
@pytest.mark.parametrize(
"response_method", ["predict", "predict_proba", "decision_function"]
)
@pytest.mark.parametrize("y_type", ["int", "str"])
def test_decision_boundary_display_many_classes(
pyplot, plot_method, response_method, y_type
):
"""Check that contour plots use all levels for classifiers with many classes.

Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/32866
"""
# Create a dataset with 11 classes
pts = np.array(
[
[-1, -1],
[-2, -1],
[1, 1],
[2, 1],
[2, 2],
[3, 2],
[3, 3],
[4, 3],
[4, 4],
[5, 4],
[5, 5],
]
)
if y_type == "int":
y = np.arange(11)
else:
y = [str(i) for i in range(11)]
clf = LogisticRegression().fit(pts, y)

disp = DecisionBoundaryDisplay.from_estimator(
clf, pts, response_method=response_method, plot_method=plot_method
)

# Check that the surface has levels for all classes when applicable
if hasattr(disp.surface_, "levels"):
expected_levels = 12 if plot_method == "contourf" else 11
assert len(disp.surface_.levels) == expected_levels
Loading