diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 49b56729aa..a2a273f81b 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -1498,24 +1498,47 @@ def estimate_statistic(self, estimator, ci, n_boot): self.value_label = "{}({})".format(estimator.__name__, self.value_label) - def draw_confints(self, ax, at_group, confint, colors, **kws): - - kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * 1.8) + def draw_confints(self, + ax, at_group, + confint, + colors, + errwidth=None, + capsize=None, + **kws): + + if errwidth is not None: + kws.setdefault("lw", errwidth) + else: + kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * 1.8) for at, (ci_low, ci_high), color in zip(at_group, confint, colors): if self.orient == "v": ax.plot([at, at], [ci_low, ci_high], color=color, **kws) + if capsize is not None: + ax.plot([at - capsize / 2, at + capsize / 2], + [ci_low, ci_low], color=color, **kws) + ax.plot([at - capsize / 2, at + capsize / 2], + [ci_high, ci_high], color=color, **kws) else: ax.plot([ci_low, ci_high], [at, at], color=color, **kws) + if capsize is not None: + ax.plot([ci_low, ci_low], + [at - capsize / 2, at + capsize / 2], + color=color, **kws) + ax.plot([ci_high, ci_high], + [at - capsize / 2, at + capsize / 2], + color=color, **kws) class _BarPlotter(_CategoricalStatPlotter): """Show point estimates and confidence intervals with bars.""" + def __init__(self, x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, - orient, color, palette, saturation, errcolor): + orient, color, palette, saturation, errcolor, errwidth=None, + capsize=None): """Initialize the plotter.""" self.establish_variables(x, y, hue, data, orient, order, hue_order, units) @@ -1523,6 +1546,8 @@ def __init__(self, x, y, hue, data, order, hue_order, self.estimate_statistic(estimator, ci, n_boot) self.errcolor = errcolor + self.errwidth = errwidth + self.capsize = capsize def draw_bars(self, ax, kws): """Draw the bars onto `ax`.""" @@ -1538,7 +1563,12 @@ def draw_bars(self, ax, 
kws): # Draw the confidence intervals errcolors = [self.errcolor] * len(barpos) - self.draw_confints(ax, barpos, self.confint, errcolors) + self.draw_confints(ax, + barpos, + self.confint, + errcolors, + self.errwidth, + self.capsize) else: @@ -1554,7 +1584,12 @@ def draw_bars(self, ax, kws): if self.confint.size: confint = self.confint[:, j] errcolors = [self.errcolor] * len(offpos) - self.draw_confints(ax, offpos, confint, errcolors) + self.draw_confints(ax, + offpos, + confint, + errcolors, + self.errwidth, + self.capsize) def plot(self, ax, bar_kws): """Make the plot.""" @@ -1569,7 +1604,7 @@ class _PointPlotter(_CategoricalStatPlotter): def __init__(self, x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, markers, linestyles, dodge, join, scale, - orient, color, palette): + orient, color, palette, errwidth=None, capsize=None): """Initialize the plotter.""" self.establish_variables(x, y, hue, data, orient, order, hue_order, units) @@ -1602,6 +1637,8 @@ def __init__(self, x, y, hue, data, order, hue_order, self.dodge = dodge self.join = join self.scale = scale + self.errwidth = errwidth + self.capsize = capsize @property def hue_offsets(self): @@ -1634,8 +1671,8 @@ def draw_points(self, ax): color=color, ls=ls, lw=lw) # Draw the confidence intervals - self.draw_confints(ax, pointpos, self.confint, self.colors, lw=lw) - + self.draw_confints(ax, pointpos, self.confint, self.colors, + self.errwidth, self.capsize) # Draw the estimate points marker = self.markers[0] if self.orient == "h": @@ -1675,7 +1712,8 @@ def draw_points(self, ax): confint = self.confint[:, j] errcolors = [self.colors[j]] * len(offpos) self.draw_confints(ax, offpos, confint, errcolors, - zorder=z, lw=lw) + self.errwidth, self.capsize, + zorder=z) # Draw the estimate points marker = self.markers[j] @@ -2025,6 +2063,16 @@ def plot(self, ax, boxplot_kws): ``1`` if you want the plot colors to perfectly match the input color spec.\ """), + capsize=dedent("""\ + capsize : float, optional 
+ Length of caps on confidence interval (drawn perpendicular to + primary line). If unspecified, no caps will be drawn. + Typical values are between 0.03 and 0.1.\ + """), + errwidth=dedent("""\ + errwidth : float, optional + Thickness of lines drawn for the confidence interval (and caps).\ + """), width=dedent("""\ width : float, optional Width of a full element when not using hue nesting, or width of all the @@ -2074,6 +2122,9 @@ def plot(self, ax, boxplot_kws): lvplot=dedent("""\ lvplot : An extension of the boxplot for long-tailed and large data sets. """), + + + ) _categorical_docs.update(_facet_docs) @@ -2831,7 +2882,7 @@ def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, estimator=np.mean, ci=95, n_boot=1000, units=None, orient=None, color=None, palette=None, saturation=.75, - errcolor=".26", ax=None, **kwargs): + errcolor=".26", errwidth=None, capsize=None, ax=None, **kwargs): # Handle some deprecated arguments if "hline" in kwargs: @@ -2850,7 +2901,7 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, plotter = _BarPlotter(x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, orient, color, palette, saturation, - errcolor) + errcolor, errwidth, capsize) if ax is None: ax = plt.gca() @@ -2894,6 +2945,8 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, errcolor : matplotlib color Color for the lines that represent the confidence interval. {ax_in} + {errwidth} + {capsize} kwargs : key, value mappings Other keyword arguments are passed through to ``plt.bar`` at draw time. 
@@ -2989,7 +3042,8 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, estimator=np.mean, ci=95, n_boot=1000, units=None, markers="o", linestyles="-", dodge=False, join=True, scale=1, - orient=None, color=None, palette=None, ax=None, **kwargs): + orient=None, color=None, palette=None, ax=None, errwidth=None, + capsize=None, **kwargs): # Handle some deprecated arguments if "hline" in kwargs: @@ -3008,7 +3062,7 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, plotter = _PointPlotter(x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, markers, linestyles, dodge, join, scale, - orient, color, palette) + orient, color, palette, errwidth, capsize) if ax is None: ax = plt.gca() @@ -3187,7 +3241,7 @@ def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, elif x is not None and y is not None: raise TypeError("Cannot pass values for both `x` and `y`") else: - raise TypeError("Must pass valus for either `x` or `y`") + raise TypeError("Must pass values for either `x` or `y`") plotter = _BarPlotter(x, y, hue, data, order, hue_order, estimator, ci, n_boot, units, diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index 3c0f921b90..be54281a5d 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -650,6 +650,32 @@ def test_draw_cis(self): plt.close("all") + # Test vertical CIs with endcaps + p.orient = "v" + + f, ax = plt.subplots() + p.draw_confints(ax, at_group, confints, colors, capsize=0.3) + capline = ax.lines[len(ax.lines) - 1] + caplinestart = capline.get_xdata()[0] + caplineend = capline.get_xdata()[1] + caplinelength = abs(caplineend - caplinestart) + nt.assert_almost_equal(caplinelength, 0.3) + nt.assert_equal(len(ax.lines), 6) + + plt.close("all") + + # Test horizontal CIs with endcaps + p.orient = "h" + + f, ax = plt.subplots() 
+ p.draw_confints(ax, at_group, confints, colors, capsize=0.3) + capline = ax.lines[len(ax.lines) - 1] + caplinestart = capline.get_ydata()[0] + caplineend = capline.get_ydata()[1] + caplinelength = abs(caplineend - caplinestart) + nt.assert_almost_equal(caplinelength, 0.3) + nt.assert_equal(len(ax.lines), 6) + # Test extra keyword arguments f, ax = plt.subplots() p.draw_confints(ax, at_group, confints, colors, lw=4) @@ -658,6 +684,15 @@ def test_draw_cis(self): plt.close("all") + # Test errwidth is set appropriately + f, ax = plt.subplots() + p.draw_confints(ax, at_group, confints, colors, errwidth=2) + capline = ax.lines[len(ax.lines)-1] + nt.assert_equal(capline._linewidth, 2) + nt.assert_equal(len(ax.lines), 2) + + plt.close("all") + class TestBoxPlotter(CategoricalFixture):