From 8a18775bbc4b316fb41d23206e0a2888aded989b Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 31 May 2022 21:01:07 -0400 Subject: [PATCH 1/7] Improve legend for categorical scatterplots --- seaborn/categorical.py | 61 +++++++++++++++++++++++------------------- seaborn/relational.py | 2 +- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 15c3e8b09e..ca2a6efe22 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -18,17 +18,17 @@ import matplotlib.patches as Patches import matplotlib.pyplot as plt -from ._oldcore import ( - VectorPlotter, +from seaborn._oldcore import ( variable_type, infer_orient, categorical_order, ) -from . import utils -from .utils import remove_na, _normal_quantile_func, _draw_figure, _default_color -from .algorithms import bootstrap -from .palettes import color_palette, husl_palette, light_palette, dark_palette -from .axisgrid import FacetGrid, _facet_docs +from seaborn.relational import _RelationalPlotter +from seaborn import utils +from seaborn.utils import remove_na, _normal_quantile_func, _draw_figure, _default_color +from seaborn.algorithms import bootstrap +from seaborn.palettes import color_palette, husl_palette, light_palette, dark_palette +from seaborn.axisgrid import FacetGrid, _facet_docs __all__ = [ @@ -39,13 +39,16 @@ ] -class _CategoricalPlotterNew(VectorPlotter): +class _CategoricalPlotterNew(_RelationalPlotter): semantics = "x", "y", "hue", "units" wide_structure = {"x": "@columns", "y": "@values", "hue": "@columns"} flat_structure = {"x": "@index", "y": "@values"} + _legend_func = "scatter" + _legend_attributes = ["color"] + def __init__( self, data=None, @@ -53,6 +56,7 @@ def __init__( order=None, orient=None, require_numeric=False, + legend="auto", ): super().__init__(data=data, variables=variables) @@ -101,6 +105,8 @@ def __init__( cat_levels = categorical_order(self.plot_data[self.cat_axis], order) self.var_levels[self.cat_axis] = 
cat_levels + self.legend = legend + def _hue_backcompat(self, color, palette, hue_order, force_hue=False): """Implement backwards compatibility for hue parametrization. @@ -272,13 +278,13 @@ def plot_strips( else: points.set_edgecolors(edgecolor) - # TODO XXX fully implement legend - show_legend = not self._redundant_hue and self.input_format != "wide" - if "hue" in self.variables and show_legend: - for level in self._hue_map.levels: - color = self._hue_map(level) - ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level) - ax.legend(loc="best", title=self.variables["hue"]) + # Finalize the axes details + self._add_axis_labels(ax) + if self.legend and not self._redundant_hue and self.input_format != "wide": + self.add_legend_data(ax) + handles, _ = ax.get_legend_handles_labels() + if handles: + ax.legend(title=self.legend_title) def plot_swarms( self, @@ -361,13 +367,13 @@ def draw(points, renderer, *, center=center): _draw_figure(ax.figure) - # TODO XXX fully implement legend - show_legend = not self._redundant_hue and self.input_format != "wide" - if "hue" in self.variables and show_legend: # TODO and legend: - for level in self._hue_map.levels: - color = self._hue_map(level) - ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level) - ax.legend(loc="best", title=self.variables["hue"]) + # Finalize the axes details + self._add_axis_labels(ax) + if self.legend and not self._redundant_hue and self.input_format != "wide": + self.add_legend_data(ax) + handles, _ = ax.get_legend_handles_labels() + if handles: + ax.legend(title=self.legend_title) class _CategoricalFacetPlotter(_CategoricalPlotterNew): @@ -2747,18 +2753,17 @@ def stripplot( data=None, *, x=None, y=None, hue=None, order=None, hue_order=None, jitter=True, dodge=False, orient=None, color=None, palette=None, size=5, edgecolor="gray", linewidth=0, ax=None, - hue_norm=None, native_scale=False, formatter=None, + hue_norm=None, native_scale=False, formatter=None, legend="auto", 
**kwargs ): - # TODO XXX we need to add a legend= param!!! - p = _CategoricalPlotterNew( data=data, variables=_CategoricalPlotterNew.get_semantics(locals()), order=order, orient=orient, require_numeric=False, + legend=legend, ) if ax is None: @@ -2869,7 +2874,7 @@ def swarmplot( data=None, *, x=None, y=None, hue=None, order=None, hue_order=None, dodge=False, orient=None, color=None, palette=None, size=5, edgecolor="gray", linewidth=0, ax=None, - hue_norm=None, native_scale=False, formatter=None, warn_thresh=.05, + hue_norm=None, native_scale=False, formatter=None, legend="auto", warn_thresh=.05, **kwargs ): @@ -2879,6 +2884,7 @@ def swarmplot( order=order, orient=orient, require_numeric=False, + legend=legend, ) if ax is None: @@ -3548,7 +3554,7 @@ def catplot( units=None, seed=None, order=None, hue_order=None, row_order=None, col_order=None, kind="strip", height=5, aspect=1, orient=None, color=None, palette=None, - legend=True, legend_out=True, sharex=True, sharey=True, + legend="auto", legend_out=True, sharex=True, sharey=True, margin_titles=False, facet_kws=None, hue_norm=None, native_scale=False, formatter=None, **kwargs @@ -3587,6 +3593,7 @@ def catplot( order=order, orient=orient, require_numeric=False, + legend=legend, ) # XXX Copying a fair amount from displot, which is not ideal diff --git a/seaborn/relational.py b/seaborn/relational.py index 1ac2f3c93a..f6e376d8b3 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -524,7 +524,7 @@ def __init__( legend=None ): - # TODO this is messy, we want the mapping to be agnoistic about + # TODO this is messy, we want the mapping to be agnostic about # the kind of plot to draw, but for the time being we need to set # this information so the SizeMapping can use it self._default_size_range = ( From 69e5f2aacd780ff2fd234ad0e2db7786206d1a39 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 31 May 2022 21:07:51 -0400 Subject: [PATCH 2/7] Move legend attribute assignment to fix empty plot --- 
seaborn/categorical.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index ca2a6efe22..a1fbda1c12 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -79,6 +79,8 @@ def __init__( require_numeric=require_numeric, ) + self.legend = legend + # Short-circuit in the case of an empty plot if not self.has_xy_data: return @@ -105,8 +107,6 @@ def __init__( cat_levels = categorical_order(self.plot_data[self.cat_axis], order) self.var_levels[self.cat_axis] = cat_levels - self.legend = legend - def _hue_backcompat(self, color, palette, hue_order, force_hue=False): """Implement backwards compatibility for hue parametrization. From 839298f34cbac5754cf5733294767bc98c06660a Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 31 May 2022 21:33:33 -0400 Subject: [PATCH 3/7] Don't create axis labels inside plotting functions --- seaborn/categorical.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index a1fbda1c12..f494a2adde 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -279,7 +279,6 @@ def plot_strips( points.set_edgecolors(edgecolor) # Finalize the axes details - self._add_axis_labels(ax) if self.legend and not self._redundant_hue and self.input_format != "wide": self.add_legend_data(ax) handles, _ = ax.get_legend_handles_labels() @@ -368,7 +367,6 @@ def draw(points, renderer, *, center=center): _draw_figure(ax.figure) # Finalize the axes details - self._add_axis_labels(ax) if self.legend and not self._redundant_hue and self.input_format != "wide": self.add_legend_data(ax) handles, _ = ax.get_legend_handles_labels() From 450b19705295c431d7e8eae350934b0f07be6be3 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 31 May 2022 23:02:19 -0400 Subject: [PATCH 4/7] Add slight hack to enable catplot with empty x/y vectors --- seaborn/categorical.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff 
--git a/seaborn/categorical.py b/seaborn/categorical.py index f494a2adde..bca818ea7e 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -39,6 +39,8 @@ ] +# Subclassing _RelationalPlotter for the legend machinery, +# but probably should move that more centrally class _CategoricalPlotterNew(_RelationalPlotter): semantics = "x", "y", "hue", "units" @@ -3582,7 +3584,6 @@ def catplot( refactored_kinds = [ "strip", "swarm", ] - if kind in refactored_kinds: p = _CategoricalFacetPlotter( @@ -3620,12 +3621,17 @@ def catplot( **facet_kws, ) + # Capture this here because scale_categorical is going to insert a (null) + # x variable even if it is empty. It's not clear whether that needs to + # happen or if disabling that is the cleaner solution. + has_xy_data = p.has_xy_data + if not native_scale or p.var_types[p.cat_axis] == "categorical": p.scale_categorical(p.cat_axis, order=order, formatter=formatter) p._attach(g) - if not p.has_xy_data: + if not has_xy_data: return g palette, hue_order = p._hue_backcompat(color, palette, hue_order) From 37ab504965125bb99687412f11a7d65d6dcb3c0e Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 1 Jun 2022 06:45:56 -0400 Subject: [PATCH 5/7] Don't set axis limits for empty categorical plot --- seaborn/categorical.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index bca818ea7e..58ac35d76d 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -180,6 +180,12 @@ def _adjust_cat_axis(self, ax, axis): if self.var_types[axis] != "categorical": return + # If both x/y data are empty, the correct way to set up the plot is + # somewhat undefined; because we don't add null category data to the plot in + # this case we don't *have* a categorical axis (yet), so best to just bail. 
+ if self.plot_data[axis].empty: + return + # We can infer the total number of categories (including those from previous # plots that are not part of the plot we are currently making) from the number # of ticks, which matplotlib sets up while doing unit conversion. This feels From cef36f3438f503411578a8100e7b9cc583e97d9b Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 1 Jun 2022 07:45:19 -0400 Subject: [PATCH 6/7] Avoid expensive and unnecessary computation when stripplot is not dodged --- seaborn/categorical.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 58ac35d76d..5740998f1d 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -262,8 +262,7 @@ def plot_strips( for sub_vars, sub_data in self.iter_data(iter_vars, from_comp_data=True, allow_empty=True): - - if offsets is not None: + if offsets is not None and (offsets != 0).any(): dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)] jitter_move = jitterer(size=len(sub_data)) if len(sub_data) > 1 else 0 From d2913f2a26456b09c7b3103091cf05ee258772af Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 1 Jun 2022 20:29:04 -0400 Subject: [PATCH 7/7] Add tests --- seaborn/categorical.py | 14 ++++++++++++-- seaborn/tests/test_categorical.py | 23 ++++++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 5740998f1d..c32acd0ad1 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -286,7 +286,12 @@ def plot_strips( points.set_edgecolors(edgecolor) # Finalize the axes details - if self.legend and not self._redundant_hue and self.input_format != "wide": + if self.legend == "auto": + show_legend = not self._redundant_hue and self.input_format != "wide" + else: + show_legend = bool(self.legend) + + if show_legend: self.add_legend_data(ax) handles, _ = ax.get_legend_handles_labels() if handles: @@ -374,7 +379,12 @@
def draw(points, renderer, *, center=center): _draw_figure(ax.figure) # Finalize the axes details - if self.legend and not self._redundant_hue and self.input_format != "wide": + if self.legend == "auto": + show_legend = not self._redundant_hue and self.input_format != "wide" + else: + show_legend = bool(self.legend) + + if show_legend: self.add_legend_data(ax) handles, _ = ax.get_legend_handles_labels() if handles: diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index fc7fc5e571..aa7317525c 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -2028,6 +2028,24 @@ def test_three_points(self): for point_color in ax.collections[0].get_facecolor(): assert tuple(point_color) == to_rgba("C0") + def test_legend_categorical(self, long_df): + + ax = self.func(data=long_df, x="y", y="a", hue="b") + legend_texts = [t.get_text() for t in ax.legend_.texts] + expected = categorical_order(long_df["b"]) + assert legend_texts == expected + + def test_legend_numeric(self, long_df): + + ax = self.func(data=long_df, x="y", y="a", hue="z") + vals = [float(t.get_text()) for t in ax.legend_.texts] + assert (vals[1] - vals[0]) == pytest.approx(vals[2] - vals[1]) + + def test_legend_disabled(self, long_df): + + ax = self.func(data=long_df, x="y", y="a", hue="b", legend=False) + assert ax.legend_ is None + def test_palette_from_color_deprecation(self, long_df): color = (.9, .4, .5) @@ -2085,9 +2103,8 @@ def test_log_scale(self): dict(data="wide", orient="h"), dict(data="long", x="x", color="C3"), dict(data="long", y="y", hue="a", jitter=False), - # TODO XXX full numeric hue legend crashes pinned mpl, disabling for now - # dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5), - # dict(data="long", x="a_cat", y="y", hue="z"), + dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5), + dict(data="long", x="a_cat", y="y", hue="z"), dict(data="long", x="y", y="s", hue="c", orient="h", 
dodge=True), dict(data="long", x="s", y="y", hue="c", native_scale=True), ]