mirror of
synced 2025-01-10 14:43:19 +01:00
559 lines
18 KiB
559 lines
18 KiB
import os
import warnings
import numpy as np
from numpy import random
from pandas.util._decorators import cache_readonly
import pandas.util._test_decorators as td
from pandas.core.dtypes.api import is_list_like
import pandas as pd
from pandas import DataFrame, Series
import pandas._testing as tm
# This is a common base class used for various plotting tests.
class TestPlotBase:
    """
    Shared fixtures and assertion helpers for the plotting test classes.

    ``setup_method`` records matplotlib version flags and builds a few small
    frames (``hist_df``, ``tdf``, ``hexbin_df``); the ``_check_*`` helpers
    assert on properties of matplotlib axes and artists.
    """

    def setup_method(self, method):
        """
        Reset matplotlib rcParams and create the fixture data for one test.
        """
        import matplotlib as mpl

        from pandas.plotting._matplotlib import compat

        # ``_check_grid_settings`` changes rcParams through ``mpl.rc``, so each
        # test starts again from matplotlib's defaults.
        mpl.rcdefaults()

        self.mpl_ge_2_2_3 = compat._mpl_ge_2_2_3()
        self.mpl_ge_3_0_0 = compat._mpl_ge_3_0_0()
        self.mpl_ge_3_1_0 = compat._mpl_ge_3_1_0()
        self.mpl_ge_3_2_0 = compat._mpl_ge_3_2_0()

        self.bp_n_objects = 7
        self.polycollection_factor = 2
        self.default_figsize = (6.4, 4.8)
        self.default_tick_position = "left"

        n = 100
        # The random columns of ``hist_df`` are drawn inside the seeded
        # context so the frame is reproducible.
        with tm.RNGContext(42):
            gender = np.random.choice(["Male", "Female"], size=n)
            classroom = np.random.choice(["A", "B", "C"], size=n)

            self.hist_df = DataFrame(
                {
                    "gender": gender,
                    "classroom": classroom,
                    "height": random.normal(66, 4, size=n),
                    "weight": random.normal(161, 32, size=n),
                    "category": random.randint(4, size=n),
                }
            )

        self.tdf = tm.makeTimeDataFrame()
        self.hexbin_df = DataFrame(
            {
                "A": np.random.uniform(size=20),
                "B": np.random.uniform(size=20),
                "C": np.arange(20) + np.random.uniform(size=20),
            }
        )

    def teardown_method(self, method):
        """
        Close every open figure so state does not leak into the next test.
        """
        self.plt.close("all")

    @cache_readonly
    def plt(self):
        """
        The ``matplotlib.pyplot`` module, imported lazily.
        """
        import matplotlib.pyplot as plt

        return plt

    @cache_readonly
    def colorconverter(self):
        """
        matplotlib's shared ``colorConverter`` instance, imported lazily.
        """
        import matplotlib.colors as colors

        return colors.colorConverter

    def _check_legend_labels(self, axes, labels=None, visible=True):
        """
        Check each axes has expected legend labels

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        labels : list-like
            expected legend labels
        visible : bool
            expected legend visibility. labels are checked only when visible is
            True

        Raises
        ------
        ValueError
            If ``visible`` is True and ``labels`` is not given.
        """
        if visible and (labels is None):
            raise ValueError("labels must be specified when visible is True")
        axes = self._flatten_visible(axes)
        for ax in axes:
            if visible:
                assert ax.get_legend() is not None
                self._check_text_labels(ax.get_legend().get_texts(), labels)
            else:
                assert ax.get_legend() is None

    def _check_legend_marker(self, ax, expected_markers=None, visible=True):
        """
        Check ax has expected legend markers

        Parameters
        ----------
        ax : matplotlib Axes object
        expected_markers : list-like
            expected legend markers
        visible : bool
            expected legend visibility. markers are checked only when visible
            is True

        Raises
        ------
        ValueError
            If ``visible`` is True and ``expected_markers`` is not given.
        """
        if visible and (expected_markers is None):
            raise ValueError("Markers must be specified when visible is True")
        if visible:
            handles, _ = ax.get_legend_handles_labels()
            markers = [handle.get_marker() for handle in handles]
            assert markers == expected_markers
        else:
            assert ax.get_legend() is None

    def _check_data(self, xp, rs):
        """
        Check each axes has identical lines

        Parameters
        ----------
        xp : matplotlib Axes object
        rs : matplotlib Axes object
        """
        xp_lines = xp.get_lines()
        rs_lines = rs.get_lines()

        assert len(xp_lines) == len(rs_lines)
        # Compare the lines pairwise, in drawing order.
        for xpl, rsl in zip(xp_lines, rs_lines):
            tm.assert_almost_equal(xpl.get_xydata(), rsl.get_xydata())

    def _check_visible(self, collections, visible=True):
        """
        Check each artist is visible or not

        Parameters
        ----------
        collections : matplotlib Artist or its list-like
            target Artist or its list or collection
        visible : bool
            expected visibility
        """
        from matplotlib.collections import Collection

        # A single non-Collection artist is wrapped so the loop handles it.
        if not isinstance(collections, Collection) and not is_list_like(collections):
            collections = [collections]

        for patch in collections:
            assert patch.get_visible() == visible

    def _get_colors_mapped(self, series, colors):
        """
        Map every value of ``series`` to the color assigned to its group.

        Unique values are paired with ``colors`` in order of first appearance.
        """
        unique = series.unique()
        # unique and colors length can be differed
        # depending on slice value
        mapped = dict(zip(unique, colors))
        return [mapped[v] for v in series.values]

    def _check_colors(
        self, collections, linecolors=None, facecolors=None, mapping=None
    ):
        """
        Check each artist has expected line colors and face colors

        Parameters
        ----------
        collections : list-like
            list or collection of target artist
        linecolors : list-like which has the same length as collections
            list of expected line colors
        facecolors : list-like which has the same length as collections
            list of expected face colors
        mapping : Series
            Series used for color grouping key
            used for andrew_curves, parallel_coordinates, radviz test
        """
        from matplotlib.collections import Collection, LineCollection, PolyCollection
        from matplotlib.lines import Line2D

        conv = self.colorconverter
        if linecolors is not None:
            if mapping is not None:
                linecolors = self._get_colors_mapped(mapping, linecolors)
                linecolors = linecolors[: len(collections)]

            assert len(collections) == len(linecolors)
            for patch, color in zip(collections, linecolors):
                if isinstance(patch, Line2D):
                    result = patch.get_color()
                    # Line2D may contains string color expression
                    result = conv.to_rgba(result)
                elif isinstance(patch, (PolyCollection, LineCollection)):
                    result = tuple(patch.get_edgecolor()[0])
                else:
                    result = patch.get_edgecolor()

                expected = conv.to_rgba(color)
                assert result == expected

        if facecolors is not None:
            if mapping is not None:
                facecolors = self._get_colors_mapped(mapping, facecolors)
                facecolors = facecolors[: len(collections)]

            assert len(collections) == len(facecolors)
            for patch, color in zip(collections, facecolors):
                if isinstance(patch, Collection):
                    # returned as list of np.array
                    result = patch.get_facecolor()[0]
                else:
                    result = patch.get_facecolor()

                if isinstance(result, np.ndarray):
                    result = tuple(result)

                expected = conv.to_rgba(color)
                assert result == expected

    def _check_text_labels(self, texts, expected):
        """
        Check each text has expected labels

        Parameters
        ----------
        texts : matplotlib Text object, or its list-like
            target text, or its list
        expected : str or list-like which has the same length as texts
            expected text label, or its list
        """
        if not is_list_like(texts):
            assert texts.get_text() == expected
        else:
            labels = [t.get_text() for t in texts]
            assert len(labels) == len(expected)
            for label, e in zip(labels, expected):
                assert label == e

    def _check_ticks_props(
        self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
    ):
        """
        Check each axes has expected tick properties

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xlabelsize : number
            expected xticks font size
        xrot : number
            expected xticks rotation
        ylabelsize : number
            expected yticks font size
        yrot : number
            expected yticks rotation
        """
        from matplotlib.ticker import NullFormatter

        axes = self._flatten_visible(axes)
        for ax in axes:
            if xlabelsize is not None or xrot is not None:
                if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
                    # If minor ticks has NullFormatter, rot / fontsize are not
                    # retained
                    labels = ax.get_xticklabels()
                else:
                    labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)

                for label in labels:
                    if xlabelsize is not None:
                        tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
                    if xrot is not None:
                        tm.assert_almost_equal(label.get_rotation(), xrot)

            if ylabelsize is not None or yrot is not None:
                if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
                    labels = ax.get_yticklabels()
                else:
                    labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)

                for label in labels:
                    if ylabelsize is not None:
                        tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
                    if yrot is not None:
                        tm.assert_almost_equal(label.get_rotation(), yrot)

    def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
        """
        Check each axes has expected scales

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xaxis : {'linear', 'log'}
            expected xaxis scale
        yaxis : {'linear', 'log'}
            expected yaxis scale
        """
        axes = self._flatten_visible(axes)
        for ax in axes:
            assert ax.xaxis.get_scale() == xaxis
            assert ax.yaxis.get_scale() == yaxis

    def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
        """
        Check expected number of axes is drawn in expected layout

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        axes_num : number
            expected number of axes. Unnecessary axes should be set to
            invisible.
        layout : tuple
            expected layout, (expected number of rows , columns)
        figsize : tuple
            expected figsize. default is ``self.default_figsize``
        """
        from pandas.plotting._matplotlib.tools import _flatten

        if figsize is None:
            figsize = self.default_figsize
        visible_axes = self._flatten_visible(axes)

        if axes_num is not None:
            assert len(visible_axes) == axes_num
            for ax in visible_axes:
                # check something drawn on visible axes
                assert len(ax.get_children()) > 0

        if layout is not None:
            # The layout is estimated from all axes, including invisible ones.
            result = self._get_axes_layout(_flatten(axes))
            assert result == layout

        tm.assert_numpy_array_equal(
            visible_axes[0].figure.get_size_inches(),
            np.array(figsize, dtype=np.float64),
        )

    def _get_axes_layout(self, axes):
        """
        Estimate the ``(rows, columns)`` grid occupied by ``axes``.

        Counts the distinct lower-left x and y coordinates of the axes
        positions.
        """
        x_set = set()
        y_set = set()
        for ax in axes:
            # check axes coordinates to estimate layout
            points = ax.get_position().get_points()
            x_set.add(points[0][0])
            y_set.add(points[0][1])
        return (len(y_set), len(x_set))

    def _flatten_visible(self, axes):
        """
        Flatten axes, and filter only visible

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        """
        from pandas.plotting._matplotlib.tools import _flatten

        return [ax for ax in _flatten(axes) if ax.get_visible()]

    def _check_has_errorbars(self, axes, xerr=0, yerr=0):
        """
        Check axes has expected number of errorbars

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xerr : number
            expected number of x errorbar
        yerr : number
            expected number of y errorbar
        """
        axes = self._flatten_visible(axes)
        for ax in axes:
            # Containers without the attribute count as having no errorbar.
            xerr_count = sum(
                1 for c in ax.containers if getattr(c, "has_xerr", False)
            )
            yerr_count = sum(
                1 for c in ax.containers if getattr(c, "has_yerr", False)
            )
            assert xerr == xerr_count
            assert yerr == yerr_count

    def _check_box_return_type(
        self, returned, return_type, expected_keys=None, check_ax_title=True
    ):
        """
        Check box returned type is correct

        Parameters
        ----------
        returned : object to be tested, returned from boxplot
        return_type : str
            return_type passed to boxplot
        expected_keys : list-like, optional
            group labels in subplot case. If not passed,
            the function checks assuming boxplot uses single ax
        check_ax_title : bool
            Whether to check the ax.title is the same as expected_key
            Intended to be checked by calling from ``boxplot``.
            Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
        """
        from matplotlib.axes import Axes

        types = {"dict": dict, "axes": Axes, "both": tuple}
        if expected_keys is None:
            # should be fixed when the returning default is changed
            if return_type is None:
                return_type = "dict"

            assert isinstance(returned, types[return_type])
            if return_type == "both":
                assert isinstance(returned.ax, Axes)
                assert isinstance(returned.lines, dict)
        else:
            # should be fixed when the returning default is changed
            if return_type is None:
                for r in self._flatten_visible(returned):
                    assert isinstance(r, Axes)
                # Nothing is keyed by group here, so the checks below
                # do not apply.
                return

            assert isinstance(returned, Series)

            assert sorted(returned.keys()) == sorted(expected_keys)
            for key, value in returned.items():
                assert isinstance(value, types[return_type])
                # check returned dict has correct mapping
                if return_type == "axes":
                    if check_ax_title:
                        assert value.get_title() == key
                elif return_type == "both":
                    if check_ax_title:
                        assert value.ax.get_title() == key
                    assert isinstance(value.ax, Axes)
                    assert isinstance(value.lines, dict)
                elif return_type == "dict":
                    line = value["medians"][0]
                    axes = line.axes
                    if check_ax_title:
                        assert axes.get_title() == key
                else:
                    raise AssertionError

    def _check_grid_settings(self, obj, kinds, kws=None):
        """
        Check that plots honour ``rcParams['axes.grid']`` and ``grid=``.

        Parameters
        ----------
        obj : object with a ``plot`` method
        kinds : list-like of str
            plot kinds to exercise
        kws : dict, optional
            extra keyword arguments forwarded to ``obj.plot``
        """
        # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
        import matplotlib as mpl

        # Avoid sharing one dict between calls.
        kws = {} if kws is None else kws

        def is_grid_on():
            xticks = self.plt.gca().xaxis.get_major_ticks()
            yticks = self.plt.gca().yaxis.get_major_ticks()
            # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
            # for new MPL, they are the same.
            if self.mpl_ge_3_1_0:
                xoff = all(not g.gridline.get_visible() for g in xticks)
                yoff = all(not g.gridline.get_visible() for g in yticks)
            else:
                xoff = all(not g.gridOn for g in xticks)
                yoff = all(not g.gridOn for g in yticks)

            return not (xoff and yoff)

        spndx = 1
        for kind in kinds:
            # rc grid off, no explicit grid argument -> grid off
            self.plt.subplot(1, 4 * len(kinds), spndx)
            spndx += 1
            mpl.rc("axes", grid=False)
            obj.plot(kind=kind, **kws)
            assert not is_grid_on()

            # rc grid on, but grid=False passed explicitly -> grid off
            self.plt.subplot(1, 4 * len(kinds), spndx)
            spndx += 1
            mpl.rc("axes", grid=True)
            obj.plot(kind=kind, grid=False, **kws)
            assert not is_grid_on()

            if kind != "pie":
                # rc grid on, no explicit grid argument -> grid on
                self.plt.subplot(1, 4 * len(kinds), spndx)
                spndx += 1
                mpl.rc("axes", grid=True)
                obj.plot(kind=kind, **kws)
                assert is_grid_on()

                # rc grid off, but grid=True passed explicitly -> grid on
                self.plt.subplot(1, 4 * len(kinds), spndx)
                spndx += 1
                mpl.rc("axes", grid=False)
                obj.plot(kind=kind, grid=True, **kws)
                assert is_grid_on()

    def _unpack_cycler(self, rcParams, field="color"):
        """
        Auxiliary function for correctly unpacking cycler after MPL >= 1.5
        """
        return [v[field] for v in rcParams["axes.prop_cycle"]]
def _check_plot_works(f, filterwarnings="always", **kwargs):
    """
    Call a plotting function twice and check that the result can be saved.

    The function is first called with ``kwargs`` as given, then again with an
    ``ax`` added (except for ``bootstrap_plot``, which is called without
    one). The current figure is finally written to a temporary file.

    Parameters
    ----------
    f : callable
        plotting function, called as ``f(**kwargs)``
    filterwarnings : str
        action passed to ``warnings.simplefilter`` for the duration of the
        call
    **kwargs
        forwarded to ``f``. ``figure`` selects the figure to draw on; without
        it the current figure is used.

    Returns
    -------
    object
        whatever the second call to ``f`` returned
    """
    import matplotlib.pyplot as plt

    ret = None
    with warnings.catch_warnings():
        warnings.simplefilter(filterwarnings)
        try:
            fig = kwargs["figure"]
        except KeyError:
            fig = plt.gcf()

        try:
            # Only the side effect matters: the figure gets an upper subplot.
            # The looked-up value itself is not used.
            kwargs.get("ax", fig.add_subplot(211))
            ret = f(**kwargs)
            tm.assert_is_valid_plot_return_object(ret)

            if f is pd.plotting.bootstrap_plot:
                assert "ax" not in kwargs
            else:
                kwargs["ax"] = fig.add_subplot(212)

            ret = f(**kwargs)
            tm.assert_is_valid_plot_return_object(ret)

            with tm.ensure_clean(return_filelike=True) as path:
                plt.savefig(path)
        finally:
            # Release the figure even when an assertion above fails.
            plt.close(fig)

    return ret
def curpath():
    """
    Return the absolute path of the directory containing this module.
    """
    return os.path.dirname(os.path.abspath(__file__))