# Source code for seaborn_extensions.swarmboxenplot

"""
A type of plot that combines swarms and box(en)/bar plots in an overlaid fashion.
"""

import typing as tp
import itertools
import warnings

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pingouin as pg
from tqdm import tqdm as _tqdm

from seaborn_extensions.types import DataFrame, Axis, Figure, Iterables
from seaborn_extensions.utils import get_grid_dims, filter_kwargs_by_callable


# Interactive usage demo for `swarmboxenplot` and related Pingouin tests.
# It is kept as a bare string literal, so it is never executed on import.
# NOTE(review): the demo uses `np` without importing numpy inside the snippet;
# add `import numpy as np` when running it on its own.
# NOTE(review): the `pd.concat((c[col].cat.codes ... for col in c), ...)` line
# iterates the *values* of the Series `c`, not columns, and looks broken as
# written — confirm before relying on it.
# NOTE(review): the main function calls `pg.pairwise_tests`; the demo's
# `pg.pairwise_ttests` may be an older/deprecated Pingouin name — check the
# installed Pingouin version.
"""
import pandas as pd
import pingouin as pg
from seaborn_extensions import swarmboxenplot
# Demo with various tests available in Pingouin:
data = pd.DataFrame(
    {"cont": np.random.random(20), "cat": np.random.choice(["a", "b"], 20)}
)
data.loc[data["cat"] == "b", "cont"] *= 5
fig, stats = swarmboxenplot(data=data, x='cat', y='cont')
data['h'] = ['cl_1'] * 10 + ['cl_2'] * 10
fig, stats = swarmboxenplot(data=data, x='cat', y='cont', hue='h')

data['cont1'] = data['cont'] + np.random.random(20)
data['cont2'] = data['cont'] + np.random.random(20)
fig, stats = swarmboxenplot(data=data, x='cat', y=['cont1', 'cont2'], hue='h')

x = 'cat'
y = 'cont'
pg.ttest(*data.groupby(x)[y].apply(lambda x: list(x)))
pg.mwu(*data.groupby(x)[y].apply(lambda x: list(x)))
pg.kruskal(data=data, between=x, dv=y)
pg.pairwise_ttests(data=data, between=x, dv=y, parametric=True)  # same as T-test
pg.pairwise_ttests(data=data, between=x, dv=y, parametric=False)  # same as MWU

c = data[x].astype(pd.CategoricalDtype())
pg.linear_regression(pd.concat((c[col].cat.codes.rename(col) for col in c), axis=1), data[y])
pg.linear_regression(pd.get_dummies(c), data[y])
pg.logistic_regression(data[y], c.cat.codes)
"""


def _get_significance_symbol(
    p: float, test_upper_threshold: float, test_lower_threshold: float
) -> str:
    """
    Map a p-value to the marker drawn on the plot.

    Returns "n.s." for missing p-values (NaN/None/pd.NA), "**" for
    p <= `test_lower_threshold`, "n.s." for p > `test_upper_threshold` and
    "*" otherwise. Missing values are checked first so that nullable pandas
    scalars (pd.NA) are never used in a comparison, which would raise.
    """
    if pd.isnull(p):
        return "n.s."
    if p <= test_lower_threshold:
        return "**"
    if p > test_upper_threshold:
        return "n.s."
    return "*"


def swarmboxenplot(
    data: DataFrame,
    x: str,
    y: tp.Union[str, Iterables],
    hue: tp.Optional[str] = None,
    swarm: bool = True,
    boxen: bool = True,
    bar: bool = False,
    orient: str = "vertical",
    plot: bool = True,
    ax: tp.Union[Axis, tp.Sequence[Axis]] = None,
    test: tp.Union[bool, str] = "mann-whitney",
    to_test: str = "all",
    multiple_testing: tp.Union[bool, str] = "fdr_bh",
    test_upper_threshold: float = 0.05,
    test_lower_threshold: float = 0.01,
    plot_non_significant: bool = False,
    plot_kws: tp.Optional[tp.Dict[str, tp.Any]] = None,
    test_kws: tp.Optional[tp.Dict[str, tp.Any]] = None,
    fig_kws: tp.Optional[tp.Dict[str, tp.Any]] = None,
    tqdm: tp.Union[bool, tp.Dict[str, tp.Any]] = True,
) -> tp.Optional[tp.Union[Figure, DataFrame, tp.Tuple[Figure, DataFrame]]]:
    """
    A categorical plot that overlays individual observations as a swarm plot
    and summary statistics about them in a boxen (or bar) plot.

    In addition, this plot will test differences between observation groups
    and add lines representing a significant difference between them.

    Parameters
    ----------
    data: pd.DataFrame
        A dataframe with data where the rows are the observations and columns
        are the variables to group them by.
    x: str
        The categorical variable.
    y: str | list[str]
        The continuous variable to plot. If a list, pandas Series or pandas
        Index of variables is given, each one is plotted in its own axes: if
        `ax` is None a new figure with one subplot per `y` variable is
        created, otherwise `ax` must provide one axes per `y` variable.
    hue: str, optional
        An optional categorical variable to further group observations by.
    swarm: bool
        Whether to plot individual observations as a swarmplot.
    boxen: bool
        Whether to plot summary statistics as a boxenplot.
        Cannot be combined with `bar`.
    bar: bool
        Whether to plot summary statistics as a barplot.
        Cannot be combined with `boxen` (set `boxen=False` to use it).
    orient: str
        Whether the plot should be oriented horizontally or vertically with
        relation to the numeric values `y`.
            - 'vertical': y-axis is `y` variable (numeric).
            - 'horizontal' (or 'horiz', 'h'): x-axis is `y` variable (numeric).
        Default is 'vertical'.
    plot: bool
        Whether to draw anything. If False, only statistics are computed.
    ax: matplotlib.axes.Axes | sequence of Axes, optional
        An optional axes to draw in (one per `y` variable if several).
    test: bool | str
        Whether to test differences between observation groups.
        If `False`, will not return a dataframe as well.
        If a string is passed, will perform test accordingly.
        Available tests:
            - 't-test'
            - 'mann-whitney'
            - 'kruskal' (requires `hue` to be None)
        `True` runs pingouin.pairwise_tests with `test_kws` unchanged.
        Default is a pairwise 'mann-whitney' test with p-value adjustment.
    to_test: str
        Whether to test all possible combinations or just within `hue` groups
        for each `x`. Only relevant when `hue` is not None.
            - 'all': a model "y ~ x * hue", i.e. test between `x` groups,
              and within `hue` for each `x`.
            - 'hue': a model "y ~ x | hue", i.e. test within `hue` for each `x`.
    multiple_testing: str | bool
        Method for multiple testing correction, passed to pingouin.multicomp.
        `False` disables correction; `True` uses "fdr_bh".
    test_upper_threshold: float
        Upper threshold to consider p-values significant.
        Will be marked with "*".
    test_lower_threshold: float
        Secondary threshold to consider p-values highly significant.
        Will be marked with "**".
    plot_non_significant: bool
        Whether to add a "n.s." sign to p-values above `test_upper_threshold`.
    plot_kws: dict
        Additional values to pass to seaborn.boxenplot, seaborn.barplot or
        seaborn.swarmplot (filtered per function with
        `filter_kwargs_by_callable`). Default is dict(palette="tab10").
    test_kws: dict
        Additional values to pass to the test function (pingouin.pairwise_tests
        or pingouin.kruskal). The dictionary is copied, never modified.
        For 'mann-whitney', `parametric=False` is set.
    fig_kws: dict
        Additional values to pass to matplotlib.pyplot.subplots when a new
        figure is created (i.e. `plot` is True and `ax` is None).
    tqdm: bool | dict
        Whether to show a progress bar when iterating over several `y`
        variables. A dict is passed as keyword arguments to tqdm (the bar is
        enabled unless the dict sets `disable`); it is copied, never modified.

    Returns
    -------
    tuple[Figure, pandas.DataFrame]:
        if `plot` is True, `ax` is None and `test` is not False.
    pandas.DataFrame:
        if `test` is not False and either `ax` is given or `plot` is False.
    Figure:
        if `plot` is True, `ax` is None and `test` is False.
    None:
        if `test` is False and either `ax` is given or `plot` is False.

    Raises
    ------
    ValueError:
        If either the `x` or `hue` column in `data` are not Category, string
        or object type, if `y` is not numeric, or if `test` is not recognized.
    """
    for var, name in [(x, "x"), (hue, "hue")]:
        if var is not None:
            if data[var].dtype.name not in ["category", "string", "object"]:
                raise ValueError(
                    f"`{name}` variable must be categorical, string or object."
                )

    # Copy so the caller's dict is never mutated (e.g. `parametric` below);
    # otherwise a later call reusing the same dict would inherit those changes.
    test_kws = dict() if test_kws is None else dict(test_kws)
    if plot_kws is None:
        plot_kws = dict(palette="tab10")
    if multiple_testing is True:
        # `pg.multicomp` needs a method name, not a boolean.
        multiple_testing = "fdr_bh"

    multiple_y = isinstance(y, (list, pd.Series, pd.Index))
    n_y = len(y) if multiple_y else 1
    if isinstance(tqdm, bool):
        tqdm_kws = dict(disable=not tqdm, total=n_y, desc="y")
    else:
        tqdm_kws = dict(tqdm)
        tqdm_kws.setdefault("disable", False)
        tqdm_kws.setdefault("total", n_y)
        tqdm_kws.setdefault("desc", "y")

    data = data.sort_values([x] + ([hue] if hue is not None else []))

    if multiple_y:
        if plot:
            # TODO: display only one legend for hue
            if ax is None:
                n, m = get_grid_dims(y)
                default_fig_kws = dict(
                    nrows=n, ncols=m, figsize=(m * 4, n * 4), sharex=True, squeeze=False
                )
                default_fig_kws.update(fig_kws or {})
                fig, axes = plt.subplots(**default_fig_kws)
                axes = axes.flatten()
            elif isinstance(ax, matplotlib.axes.Axes):
                axes = np.asarray([ax])
            else:
                # numpy array or any (possibly nested) sequence of Axes
                axes = np.asarray(ax).flatten()
        else:
            axes = [None] * n_y

        has_stats = test is not False
        _stats = list()
        idx = -1
        for idx, _var in _tqdm(enumerate(y), **tqdm_kws):
            _ax = axes[idx]
            s: DataFrame = swarmboxenplot(
                data=data,
                x=x,
                y=_var,
                hue=hue,
                swarm=swarm,
                boxen=boxen,
                bar=bar,
                orient=orient,
                plot=plot,
                ax=_ax,
                test=test,
                to_test=to_test,
                multiple_testing=multiple_testing,
                test_upper_threshold=test_upper_threshold,
                test_lower_threshold=test_lower_threshold,
                plot_non_significant=plot_non_significant,
                plot_kws=plot_kws,
                test_kws=test_kws,
            )
            if plot:
                _ax.set(title=f"{_var}{_ax.get_title()}", xlabel=None, ylabel=None)
            if has_stats:
                _stats.append(s.assign(Variable=_var))
        # "close" excess subplots
        if plot:
            for _ax in axes[idx + 1 :]:
                _ax.axis("off")

        if not has_stats:
            return fig if (plot and ax is None) else None

        stats = pd.concat(_stats).reset_index(drop=True)
        cols = [c for c in stats.columns if c != "Variable"]
        stats = stats.reindex(["Variable"] + cols, axis=1)
        # If there is just one test per `y` (no hue), correct p-values jointly
        # across all `y` variables instead of per variable.
        if multiple_testing is not False and stats.shape[0] == n_y:
            stats["p-cor"] = pg.multicomp(
                stats["p-unc"].tolist(), method=multiple_testing
            )[1]
        if plot and ax is None:
            return fig, stats
        return stats

    if data[y].dtype.name in ["category", "string", "object"]:
        raise ValueError("`y` variable must be numeric.")

    horizontal = orient in ["horizontal", "horiz", "h"]
    if horizontal:
        # seaborn infers orientation from which axis holds the numeric variable
        x, y = y, x
    # name of the categorical column, whatever the orientation
    cat_col = y if horizontal else x

    # Plot vanilla seaborn
    if plot:
        if ax is None:
            default_fig_kws = dict(figsize=(4, 4))
            default_fig_kws.update(fig_kws or {})
            fig, _ax = plt.subplots(**default_fig_kws)
        else:
            _ax = ax
        if boxen:
            assert not bar
            # Tmp fix for lack of support for Pandas Int64 in boxenplot
            # (applied to the numeric column whatever the orientation):
            num_col = x if horizontal else y
            if data[num_col].dtype.name == "Int64":
                data[num_col] = data[num_col].astype(float)
            boxen_kws = filter_kwargs_by_callable(plot_kws, sns.boxenplot)
            sns.boxenplot(data=data, x=x, y=y, hue=hue, ax=_ax, **boxen_kws)
        if bar:
            assert not boxen
            bar_kws = filter_kwargs_by_callable(plot_kws, sns.barplot)
            sns.barplot(data=data, x=x, y=y, hue=hue, ax=_ax, **bar_kws)
        if (boxen or bar) and swarm:
            _add_transparency_to_plot(_ax, kind="bar" if bar else "boxen")
        if swarm:
            swarm_kws = filter_kwargs_by_callable(plot_kws, sns.swarmplot)
            if hue is not None and "dodge" not in swarm_kws:
                swarm_kws["dodge"] = True
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                sns.swarmplot(
                    data=data,
                    x=x,
                    y=y,
                    # without `hue`, color points by the categorical variable
                    hue=hue if hue is not None else cat_col,
                    legend="auto" if hue is not None else False,
                    ax=_ax,
                    **swarm_kws,
                )
        if horizontal:
            _ax.set_yticklabels(_ax.get_yticklabels(), rotation=0, ha="right")
        else:
            _ax.set_xticklabels(_ax.get_xticklabels(), rotation=90, ha="right")

    if test is False:
        return fig if (plot and ax is None) else None

    # Perform testing
    if horizontal:
        # restore `x` as the categorical and `y` as the numeric variable
        x, y = y, x

    if test in [True, "t-test", "mann-whitney"]:
        test_function = pg.pairwise_tests
        if test == "mann-whitney":
            test_kws["parametric"] = False
    elif test in ["kruskal"]:
        test_function = pg.kruskal
        assert hue is None, "If test is 'kruskal', 'hue' must be None."
    else:
        raise ValueError(f"Test type '{test}' not recognized.")

    if not data.index.is_unique:
        print("Warning: dataframe contains a duplicated index.")

    # # remove NaNs
    datat = data.dropna(subset=[x, y] + ([hue] if hue is not None else []))
    # # remove categories with only one element
    group_sizes = datat.groupby(x).size()
    keep = group_sizes[group_sizes > 1].index
    # explicit copy: the column is reassigned below
    datat = datat.loc[datat[x].isin(keep), :].copy()
    if datat[x].dtype.name == "category":
        datat[x] = datat[x].cat.remove_unused_categories()

    if plot:
        ylim = _ax.get_ylim()  # save original axis boundaries for later
        ylength = abs(ylim[1]) + (abs(ylim[0]) if ylim[0] < 0 else 0)

    # # Now calculate stats
    # # # get empty dataframe in case nothing can be calculated
    stat = _get_empty_stat_results(datat, x, y, hue, add_median=True)
    # # # mirror groups to account for own pingouin order
    tats = stat.rename(
        columns={
            "B": "A",
            "A": "B",
            "median_A": "median_B",
            "median_B": "median_A",
        }
    )
    stat = (
        pd.concat([stat, tats])
        .sort_values(["Contrast", "A", "B"])
        .reset_index(drop=True)
    )
    try:
        _stat = test_function(
            data=datat,
            dv=y,
            between=x if hue is None else [x, hue],
            **test_kws,
        )
    except (AssertionError, ValueError) as e:
        print(str(e))
        _stat = stat
    except KeyError:
        print("Only one category with values!")
        _stat = stat

    if test == "kruskal":
        # A single omnibus test: its significance is shown as the axes title.
        if plot:
            # the fallback `stat` frame has several (or zero) rows -> untested
            p = _stat["p-unc"].iloc[0] if _stat.shape[0] == 1 else np.nan
            _ax.set_title(
                _get_significance_symbol(p, test_upper_threshold, test_lower_threshold)
            )
            if ax is None:
                return fig, _stat
        return _stat

    stat = _stat.merge(
        stat[
            ["Contrast", "A", "B", "median_A", "median_B"]
            + ([x] if hue is not None else [])
        ],
        how="left",
    ).convert_dtypes()
    # the `x` column only exists when `hue` is given
    if to_test == "hue" and hue is not None:
        stat = stat.loc[stat[x] != "-", :]

    if multiple_testing is not False:
        if "p-unc" not in stat.columns:
            stat["p-unc"] = np.nan
        stat["p-cor"] = pg.multicomp(
            stat["p-unc"].astype(float).values, method=multiple_testing
        )[1]
        pcol = "p-cor"
    else:
        pcol = "p-unc"

    if not plot:
        return stat

    # Plot
    # # This ensures there is a point for each `x` class and keep the order correct for below
    mm = data.groupby([x] + ([hue] if hue is not None else []))[y].median()
    if hue is None:
        order = {k: float(i) for i, k in enumerate(mm.index)}
    else:
        nhues = data[hue].drop_duplicates().dropna().shape[0]
        order = {
            k: (float(i) / nhues) - (1 / nhues) - 0.05
            for i, k in enumerate(mm.index)
        }
    # materialize positions: numpy does not treat dict views as sequences
    positions = list(order.values())
    if horizontal:
        _ax.scatter(mm, positions, alpha=0, color="white")
    else:
        _ax.scatter(positions, mm, alpha=0, color="white")

    # # Plot significance bars
    # # # start at top of the plot and progressively decrease sig. bar downwards
    py = data[y].max()
    incr = ylength / 100  # divide yaxis in 100 steps
    for _, row in stat.iterrows():
        p = row[pcol]
        if (pd.isnull(p) or (p > test_upper_threshold)) and (not plot_non_significant):
            py -= incr
            continue
        symbol = _get_significance_symbol(
            p, test_upper_threshold, test_lower_threshold
        )
        if hue is not None:
            if row[x] != "-":
                xx = (order[(row[x], row["A"])], order[(row[x], row["B"])])
            else:
                try:
                    # TODO: get more accurate middle of group
                    xx = (
                        order[(row["A"], stat["A"].iloc[-1])] - (1 / nhues),
                        order[(row["B"], stat["B"].iloc[-1])] - (1 / nhues),
                    )
                except KeyError:
                    # These are the hue groups without contrasting on 'x'
                    continue
        else:
            xx = (order[row["A"]], order[row["B"]])
        _tp = (0.35 + xx[0], 0.35 + xx[1] - 0.25), (py, py)
        _tp2 = xx[1] - 0.025, py
        if horizontal:
            _tp = _tp[::-1]
            _tp2 = _tp2[::-1]
        _ax.plot(*_tp, color="black", linewidth=1.2)
        _ax.text(
            *_tp2,
            s=symbol,
            color="black",
            ha="center",
            rotation=90 if horizontal else 0,
        )
        py -= incr
    _ax.set_ylim(ylim)
    return (fig, stat) if ax is None else stat
def _add_transparency_to_plot(
    ax: Axis, alpha: float = 0.25, kind: str = "boxen"
) -> None:
    """
    Set the transparency of summary-plot artists already drawn on `ax`.

    `swarmboxenplot` calls this after drawing the boxen/bar plot and before
    drawing the swarm plot.

    Parameters
    ----------
    ax: matplotlib.axes.Axes
        Axes whose child artists are modified in place.
    alpha: float
        Value passed to `set_alpha` of every matching artist.
    kind: str
        "boxen" matches PatchCollection and PathCollection children; any
        other value matches Rectangle patches (presumably the bars of a bar
        plot).
    """
    # Artist types to fade. The parentheses around `Rectangle` do not build a
    # tuple, but `isinstance` accepts a single class as well.
    objs = (
        (
            matplotlib.collections.PatchCollection,
            matplotlib.collections.PathCollection,
        )
        if kind == "boxen"
        else (matplotlib.patches.Rectangle)
    )
    for x in ax.get_children():
        if isinstance(x, objs):
            x.set_alpha(alpha)


def _get_empty_stat_results(
    data: DataFrame,
    x: str,
    y: str,
    hue: tp.Optional[str] = None,
    add_median: bool = True,
) -> DataFrame:
    """
    Build a table of all pairwise contrasts with untested (NaN) p-values.

    `swarmboxenplot` uses it as a fallback result when the test fails and to
    attach group medians to the test results.

    Rows:
        - every unordered pair of distinct `x` values ("Contrast" == `x`);
        - if `hue` is given, every unordered pair of distinct `hue` values,
          once with the `x` column set to "-" ("Contrast" == `hue`) and once
          per `x` value ("Contrast" == "{x} * {hue}").

    Parameters
    ----------
    data: pd.DataFrame
        Observations; `x`, `y` and (optionally) `hue` are column names.
    x: str
        Categorical variable.
    y: str
        Numeric variable whose medians are attached.
    hue: str, optional
        Secondary categorical variable.
    add_median: bool
        Whether to add "median_A"/"median_B" columns via left merges.

    Returns
    -------
    pd.DataFrame
        Columns "A", "B", "Contrast", `x` (only when `hue` is given),
        "Tested" (False), "p-unc" (NaN), plus "median_A"/"median_B" if
        `add_median`.
    """
    stat = pd.DataFrame(
        itertools.combinations(data[x].drop_duplicates(), 2),
        columns=["A", "B"],
    )
    stat["Contrast"] = x
    if hue is not None:
        # pairs of hue values compared across all `x` ("-")...
        huestat = pd.DataFrame(
            itertools.combinations(data[hue].drop_duplicates(), 2),
            columns=["A", "B"],
        )
        huestat["Contrast"] = hue
        huestat[x] = "-"
        _to_append = [huestat]
        # ...and within each `x` value
        for v in data[x].unique():
            n = huestat.copy()
            n[x] = v
            n["Contrast"] = f"{x} * {hue}"
            _to_append.append(n)
        to_append = pd.concat(_to_append)
        stat = pd.concat([stat, to_append]).sort_values([x, "A", "B"])
        # the `x`-only contrasts have no `x` column value after concat
        stat[x] = stat[x].fillna("-")
    stat["Tested"] = False
    stat["p-unc"] = np.nan
    if add_median:
        _mm = [data.groupby(x)[y].median().reset_index()]
        if hue is not None:
            # `x`-level medians are stored under the `hue` column name so that,
            # after renaming `hue` to "A"/"B" below, they match the
            # Contrast == `x` rows (whose `x` column is "-").
            _mm[0] = _mm[0].rename(columns={x: hue})
            _mm.append(data.groupby(hue)[y].median().reset_index())
            _p = data.groupby([x, hue])[y].median().reset_index()
            # remove categories if existing (workaround):
            # NOTE(review): going through `.values` likely yields object dtype
            # for all columns, including the medians — confirm.
            _p = pd.DataFrame(_p.values, index=_p.index, columns=_p.columns)
            _mm.append(_p)
        mm = pd.concat(_mm)
        # rows without an `x` value get "-" so they can match on `x` in the merge
        if mm[x].dtype.name == "category":
            mm[x] = mm[x].cat.add_categories(["-"]).fillna("-")
        else:
            mm[x] = mm[x].fillna("-")
        # mm = mm.append(data.groupby([x, hue])[y].std().reset_index()).fillna("-")
        # Left merges on the shared columns (group label, plus `x` when `hue`
        # is given). NOTE(review): if a label appears more than once in `mm`
        # (e.g. an `x` value equal to a `hue` value) rows may be duplicated.
        for col in ["A", "B"]:
            stat = stat.merge(
                mm.rename(
                    columns={
                        hue if hue is not None else x: f"{col}",
                        y: f"median_{col}",
                    }
                ),
                how="left",
            )
    return stat