这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 69 additions & 15 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,31 +1498,56 @@ def estimate_statistic(self, estimator, ci, n_boot):
self.value_label = "{}({})".format(estimator.__name__,
self.value_label)

def draw_confints(self, ax, at_group, confint, colors,
                  errwidth=None, capsize=None, **kws):
    """Draw confidence-interval lines, with optional end caps, onto ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw onto; only its ``plot`` method is called.
    at_group : sequence of floats
        Positions along the categorical axis, one per interval.
    confint : sequence of (low, high) pairs
        Interval bounds on the value axis, aligned with ``at_group``.
    colors : sequence of colors
        Line color for each interval, aligned with ``at_group``.
    errwidth : float, optional
        Line width for the interval (and cap) lines. When None, the width
        defaults to 1.8 times ``rcParams["lines.linewidth"]`` so it follows
        the active matplotlib style/context.
    capsize : float, optional
        Total length of each cap, drawn perpendicular to the interval line
        and centered on its position. When None, no caps are drawn.
    kws : key, value mappings
        Passed through to ``ax.plot``. An explicit ``lw`` here takes
        precedence over ``errwidth``.

    Notes
    -----
    For each interval the main line is drawn first, then (if ``capsize``
    is given) the cap at ``ci_low`` and then the cap at ``ci_high``.
    """
    if errwidth is not None:
        kws.setdefault("lw", errwidth)
    else:
        # Derive the default from rcParams instead of a fixed 1.8 so the
        # error bars scale with the current style/context.
        kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * 1.8)

    vertical = self.orient == "v"

    for at, (ci_low, ci_high), color in zip(at_group, confint, colors):
        if vertical:
            ax.plot([at, at], [ci_low, ci_high], color=color, **kws)
        else:
            ax.plot([ci_low, ci_high], [at, at], color=color, **kws)

        if capsize is None:
            continue

        # Caps extend half their length to each side of the position;
        # the float divisor avoids integer division for integer capsize.
        half = capsize / 2.
        cap_span = [at - half, at + half]
        for ci_end in (ci_low, ci_high):
            if vertical:
                ax.plot(cap_span, [ci_end, ci_end], color=color, **kws)
            else:
                ax.plot([ci_end, ci_end], cap_span, color=color, **kws)


class _BarPlotter(_CategoricalStatPlotter):
"""Show point estimates and confidence intervals with bars."""

def __init__(self, x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
orient, color, palette, saturation, errcolor):
orient, color, palette, saturation, errcolor, errwidth=None,
capsize=None):
"""Initialize the plotter."""
self.establish_variables(x, y, hue, data, orient,
order, hue_order, units)
self.establish_colors(color, palette, saturation)
self.estimate_statistic(estimator, ci, n_boot)

self.errcolor = errcolor
self.errwidth = errwidth
self.capsize = capsize

def draw_bars(self, ax, kws):
"""Draw the bars onto `ax`."""
Expand All @@ -1538,7 +1563,12 @@ def draw_bars(self, ax, kws):

# Draw the confidence intervals
errcolors = [self.errcolor] * len(barpos)
self.draw_confints(ax, barpos, self.confint, errcolors)
self.draw_confints(ax,
barpos,
self.confint,
errcolors,
self.errwidth,
self.capsize)

else:

Expand All @@ -1554,7 +1584,12 @@ def draw_bars(self, ax, kws):
if self.confint.size:
confint = self.confint[:, j]
errcolors = [self.errcolor] * len(offpos)
self.draw_confints(ax, offpos, confint, errcolors)
self.draw_confints(ax,
offpos,
confint,
errcolors,
self.errwidth,
self.capsize)

def plot(self, ax, bar_kws):
"""Make the plot."""
Expand All @@ -1569,7 +1604,7 @@ class _PointPlotter(_CategoricalStatPlotter):
def __init__(self, x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
markers, linestyles, dodge, join, scale,
orient, color, palette):
orient, color, palette, errwidth=None, capsize=None):
"""Initialize the plotter."""
self.establish_variables(x, y, hue, data, orient,
order, hue_order, units)
Expand Down Expand Up @@ -1602,6 +1637,8 @@ def __init__(self, x, y, hue, data, order, hue_order,
self.dodge = dodge
self.join = join
self.scale = scale
self.errwidth = errwidth
self.capsize = capsize

@property
def hue_offsets(self):
Expand Down Expand Up @@ -1634,8 +1671,8 @@ def draw_points(self, ax):
color=color, ls=ls, lw=lw)

# Draw the confidence intervals
self.draw_confints(ax, pointpos, self.confint, self.colors, lw=lw)

self.draw_confints(ax, pointpos, self.confint, self.colors,
self.errwidth, self.capsize)
# Draw the estimate points
marker = self.markers[0]
if self.orient == "h":
Expand Down Expand Up @@ -1675,7 +1712,8 @@ def draw_points(self, ax):
confint = self.confint[:, j]
errcolors = [self.colors[j]] * len(offpos)
self.draw_confints(ax, offpos, confint, errcolors,
zorder=z, lw=lw)
self.errwidth, self.capsize,
zorder=z)

# Draw the estimate points
marker = self.markers[j]
Expand Down Expand Up @@ -2025,6 +2063,16 @@ def plot(self, ax, boxplot_kws):
``1`` if you want the plot colors to perfectly match the input color
spec.\
"""),
capsize=dedent("""\
capsize : float, optional
Length of caps on confidence interval (drawn perpendicular to
primary line). If unspecified, no caps will be drawn.
Typical values are between 0.03 and 0.1.\
"""),
errwidth=dedent("""\
errwidth : float, optional
Thickness of lines drawn for the confidence interval (and caps).\
"""),
width=dedent("""\
width : float, optional
Width of a full element when not using hue nesting, or width of all the
Expand Down Expand Up @@ -2074,6 +2122,9 @@ def plot(self, ax, boxplot_kws):
lvplot=dedent("""\
lvplot : An extension of the boxplot for long-tailed and large data sets.
"""),



)

_categorical_docs.update(_facet_docs)
Expand Down Expand Up @@ -2831,7 +2882,7 @@ def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
estimator=np.mean, ci=95, n_boot=1000, units=None,
orient=None, color=None, palette=None, saturation=.75,
errcolor=".26", ax=None, **kwargs):
errcolor=".26", errwidth=None, capsize=None, ax=None, **kwargs):

# Handle some deprecated arguments
if "hline" in kwargs:
Expand All @@ -2850,7 +2901,7 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
plotter = _BarPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
orient, color, palette, saturation,
errcolor)
errcolor, errwidth, capsize)

if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -2894,6 +2945,8 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
errcolor : matplotlib color
Color for the lines that represent the confidence interval.
{ax_in}
{errwidth}
{capsize}
kwargs : key, value mappings
Other keyword arguments are passed through to ``plt.bar`` at draw
time.
Expand Down Expand Up @@ -2989,7 +3042,8 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
estimator=np.mean, ci=95, n_boot=1000, units=None,
markers="o", linestyles="-", dodge=False, join=True, scale=1,
orient=None, color=None, palette=None, ax=None, **kwargs):
orient=None, color=None, palette=None, ax=None, errwidth=None,
capsize=None, **kwargs):

# Handle some deprecated arguments
if "hline" in kwargs:
Expand All @@ -3008,7 +3062,7 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
plotter = _PointPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
markers, linestyles, dodge, join, scale,
orient, color, palette)
orient, color, palette, errwidth, capsize)

if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -3187,7 +3241,7 @@ def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
elif x is not None and y is not None:
raise TypeError("Cannot pass values for both `x` and `y`")
else:
raise TypeError("Must pass valus for either `x` or `y`")
raise TypeError("Must pass values for either `x` or `y`")

plotter = _BarPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
Expand Down
35 changes: 35 additions & 0 deletions seaborn/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,32 @@ def test_draw_cis(self):

plt.close("all")

# Test vertical CIs with endcaps
p.orient = "v"

f, ax = plt.subplots()
p.draw_confints(ax, at_group, confints, colors, capsize=0.3)
capline = ax.lines[len(ax.lines) - 1]
caplinestart = capline.get_xdata()[0]
caplineend = capline.get_xdata()[1]
caplinelength = abs(caplineend - caplinestart)
nt.assert_almost_equal(caplinelength, 0.3)
nt.assert_equal(len(ax.lines), 6)

plt.close("all")

# Test horizontal CIs with endcaps
p.orient = "h"

f, ax = plt.subplots()
p.draw_confints(ax, at_group, confints, colors, capsize=0.3)
capline = ax.lines[len(ax.lines) - 1]
caplinestart = capline.get_ydata()[0]
caplineend = capline.get_ydata()[1]
caplinelength = abs(caplineend - caplinestart)
nt.assert_almost_equal(caplinelength, 0.3)
nt.assert_equal(len(ax.lines), 6)

# Test extra keyword arguments
f, ax = plt.subplots()
p.draw_confints(ax, at_group, confints, colors, lw=4)
Expand All @@ -658,6 +684,15 @@ def test_draw_cis(self):

plt.close("all")

# Test errwidth is set appropriately
f, ax = plt.subplots()
p.draw_confints(ax, at_group, confints, colors, errwidth=2)
capline = ax.lines[len(ax.lines)-1]
nt.assert_equal(capline._linewidth, 2)
nt.assert_equal(len(ax.lines), 2)

plt.close("all")


class TestBoxPlotter(CategoricalFixture):

Expand Down