##############################################################################
# Braviz, Brain Data interactive visualization #
# Copyright (C) 2014 Diego Angulo #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU Lesser General Public License as #
# published by the Free Software Foundation, either version 3 of the #
# License, or (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU Lesser General Public License for more details. #
# #
# You should have received a copy of the GNU Lesser General Public License#
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
##############################################################################
from __future__ import division
__author__ = 'Diego'
from PyQt4 import QtCore
from PyQt4 import QtGui
import matplotlib
from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import matplotlib.axes
import matplotlib.gridspec as gridspec
from itertools import izip
import numpy as np
import pandas as pd
import logging
import seaborn as sns
class AbstractPlot(object):
    """
    Common interface of the plots hosted by the :class:`MatplotWidget`

    Only :meth:`redraw` has to be provided by concrete plots, the other
    hooks do nothing unless they are overridden.
    """

    def redraw(self):
        """
        Paint the plot contents again, called when the widget is resized

        Raises:
            NotImplementedError : always, concrete plots must override this method
        """
        raise NotImplementedError("must be implemented")

    def add_subjects(self, subjs):
        """
        Hook for highlighting several points of the plot, a no-op here

        Args:
            subjs (list) : List of subjects to highlight
        """

    def highlight(self, subj):
        """
        Hook for highlighting a single point of the plot, a no-op here

        Args:
            subj : Id of point to highlight
        """

    def get_last_id(self):
        """
        Report the point last signaled with the cursor, the MatplotWidget uses
        this value to build a context menu

        Returns:
            Id of the last point for which a tooltip was requested, ``None`` in this base class
        """
class MatplotBarPlot(AbstractPlot):
    """
    Draws a bar plot on the :class:`MatplotWidget`.

    One bar is drawn per row of the input data: the first column gives the bar length and the
    index gives the tick labels. Rows with missing values are dropped. Vertical plots are sorted
    with the largest value first, horizontal plots with the smallest value at position zero.
    When a second column is present it is used as a nominal variable, each of its levels
    is drawn with its own color and gets a legend entry.

    To create a bar plot call :meth:`MatplotWidget.draw_bars`

    Args:
        axes (matplotlib.axes.Axes) : Axes to draw into, they are cleared on every redraw
        data (pandas.DataFrame) : Values in the first column, optional group levels in the second one
        ylims (tuple) : Optional ``(low, high)`` limits for the value axis, it is autoscaled when ``None``
        orientation (str) : ``"vertical"`` for vertical bars, any other value draws horizontal bars
        group_labels (dict) : Optional mapping from group level to the text used in legend and tooltips
    """

    def __init__(self, axes, data, ylims=None, orientation="vertical", group_labels=None):
        sns.set_style("darkgrid")
        self.highlight_color = '#000000'
        self.highlighted = None
        self.axes = axes
        self.orientation = orientation
        self.group_labels = group_labels
        self.grouped = data.shape[1] >= 2
        assert isinstance(self.axes, matplotlib.axes.Axes)
        # redraw() clears the axes, so limits applied here would be thrown away.
        # Keep the limits requested by the caller and re-apply them after drawing.
        self.ylims = ylims
        col0 = data.columns[0]
        self.col0 = col0

        # drop incomplete rows and sort by value
        ascending = self.orientation != "vertical"
        data2 = data.dropna()
        if hasattr(data2, "sort_values"):
            data2 = data2.sort_values(col0, ascending=ascending)
        else:
            # pandas releases without sort_values only offer DataFrame.sort
            data2 = data2.sort(col0, ascending=ascending)
        heights = data2[col0].values
        pos = np.arange(len(heights))
        # appended last, so columns[1] is still the grouping column when there is one
        data2["_pos"] = pos

        self.data = data2
        self.pos = pos
        self.heights = heights

        if self.grouped:
            groups = data2.groupby(data2.columns[1])
            self.colors_list = sns.color_palette('Set1', len(groups))
            self.colors_dict = dict(
                (name, self.colors_list[i])
                for i, (name, group) in enumerate(groups) if len(group) > 0)
        else:
            # a single color is enough, and it keeps colors_list[0] valid for empty data
            self.colors_list = sns.color_palette('Set1', 1)
            self.colors_dict = {}
        self.last_id = None
        self.redraw()

    def redraw(self):
        """
        Clear the axes and draw bars, ticks, axis labels and legend again
        """
        self.axes.cla()
        log = logging.getLogger(__name__)
        colors_list = self.colors_list
        if not self.grouped:
            self.__draw_bars_and_higlight(self.data, "_nolegend_", colors_list[0])
        else:
            groups = self.data.groupby(self.data.columns[1])
            for i, (name, group) in enumerate(groups):
                if len(group) == 0:
                    continue
                label = None
                if self.group_labels is not None:
                    # levels without an entry fall back to the generic label below
                    label = self.group_labels.get(name)
                if not label:
                    label = "Level %s" % name
                log.debug(label)
                self.__draw_bars_and_higlight(group, label, colors_list[i])

        ix_name = self.data.index.name
        if self.orientation == "vertical":
            self.axes.set_xticks(self.pos)
            self.axes.set_xticklabels(self.data.index)
            self.axes.set_xlim(-0.5, len(self.pos) - 0.5)
            if self.ylims is not None:
                self.axes.set_ylim(*self.ylims)
            self.axes.tick_params(
                'y', left=False, right=True, labelleft=False, labelright=True)
            self.axes.tick_params(
                'x', top=False, bottom=True, labelbottom=True, labeltop=False)
            self.axes.get_yaxis().set_label_position("right")
            self.axes.set_ylabel(self.col0)
            if ix_name is not None:
                self.axes.set_xlabel(ix_name)
        else:
            self.axes.set_yticks(self.pos)
            self.axes.set_yticklabels(self.data.index)
            self.axes.set_ylim(-0.5, len(self.pos) - 0.5)
            if self.ylims is not None:
                # for horizontal bars the values run along the x axis
                self.axes.set_xlim(*self.ylims)
            self.axes.tick_params(
                'y', left=True, right=False, labelleft=True, labelright=False)
            self.axes.tick_params(
                'x', top=False, bottom=True, labelbottom=True, labeltop=False)
            self.axes.get_yaxis().set_label_position("left")
            self.axes.set_xlabel(self.col0)
            if ix_name is not None:
                self.axes.set_ylabel(ix_name)
        if self.grouped and len(self.data) > 0:
            self.axes.legend(loc="lower right")

    def __draw_bars_and_higlight(self, data, label, color):
        """
        Draw one bar per row of *data* and outline the highlighted subject

        Args:
            data (pandas.DataFrame) : Rows to draw, must contain the value column and ``_pos``
            label (str) : Legend label for this set of bars
            color : Fill color for this set of bars
        """
        positions = data["_pos"].values
        values = data[self.col0].values
        if self.orientation == "vertical":
            # the label is required here too, otherwise grouped vertical plots get an empty legend
            patches = self.axes.bar(
                positions, values, align="center", picker=5, label=label, color=color)
        else:
            patches = self.axes.barh(
                positions, values, height=0.8, align="center", picker=5, label=label, color=color)
        for i, p in enumerate(patches):
            # the url carries the row id so pick events can be mapped back to the data
            p.set_url(data.index[i])
            if data.index[i] == self.highlighted:
                p.set_linewidth(2)
                p.set_ec(self.highlight_color)

    def highlight(self, subj):
        """
        Select the bar that is outlined the next time the plot is redrawn

        Args:
            subj : Index value of the row to highlight
        """
        self.highlighted = subj

    def get_tooltip(self, event):
        """
        Build the tooltip text for a picked bar and remember its id

        Args:
            event : Pick event whose artist is one of the bars of this plot

        Returns:
            str with the subject id, its value and, for grouped plots, its group
        """
        subj = event.artist.get_url()
        self.last_id = subj
        data = self.data
        col0 = data.columns[0]
        message_rows = ["%s:" % subj, "%s : %.2f" % (col0, data.loc[subj, col0])]
        if self.grouped:
            col1 = data.columns[1]
            label = data.loc[subj, col1]
            if self.group_labels is not None:
                # show the raw level when it has no label
                label = self.group_labels.get(label, label)
            message_rows.append("%s : %s" % (col1, label))
        return "\n".join(message_rows)

    def get_last_id(self):
        """
        Returns:
            Id of the last bar for which a tooltip was requested, ``None`` before the first one
        """
        return self.last_id
class CoefficientsPlot(AbstractPlot):
    """
    Draws a coefficient plot to illustrate the results of a linear regression.

    The plot shows the 95% confidence intervals and standard errors. For a coefficient to
    be significant its confidence interval should not cross the zero line. For it to have an important
    effect it should be far from zero.

    The input DataFrame should contain the results of a linear regression with normalized variables.
    The expected columns are

        - (index) : Coefficient names
        - CI_95 : lower and upper limit of the 95% confidence interval
        - Std_error : The standard error magnitude
        - Slope : slope of the coefficients in the regression

    Also the first row in the dataframe should be the intercept, it is skipped
    unless *draw_intercept* is true.

    Use :meth:`MatplotWidget.draw_coefficients_plot` to create this plot.

    Args:
        axes : matplotlib axes to draw into
        coefs_df (pandas.DataFrame) : Regression results with the columns listed above
        draw_intercept (bool) : Keep the first row (the intercept) in the plot
    """

    def __init__(self, axes, coefs_df, draw_intercept=False):
        sns.set_style("darkgrid")
        self.axes = axes
        if not draw_intercept:
            self._df = coefs_df.iloc[1:].copy()
        else:
            self._df = coefs_df.copy()
        self.centers = self._df.Slope
        self.l95 = [interval[0] for interval in self._df.CI_95]
        self.h95 = [interval[1] for interval in self._df.CI_95]
        self.l68 = self.centers - self._df.Std_error
        self.h68 = self.centers + self._df.Std_error
        self.names = list(self._df.index)
        self.n_coefs = len(self._df)
        self.pos = list(range(self.n_coefs))
        # second color of the active palette; rcParams['axes.color_cycle'] is gone in matplotlib 2+
        self.color = sns.color_palette()[1]
        self.redraw()

    def redraw(self):
        """
        Clear the axes and draw intervals, standard errors and slopes again
        """
        self.axes.clear()
        # styled after clear() so the settings can not be discarded by it
        self.axes.tick_params(
            'x', bottom=True, labelbottom=True, labeltop=False, top=False)
        self.axes.tick_params(
            'y', left=True, labelleft=True, labelright=False, right=False)
        self.axes.yaxis.set_label_position("right")
        self.axes.set_ylim(-0.5, self.n_coefs - 0.5, auto=False)
        self.axes.set_xlim(-1, 1, auto=True)
        self.axes.axvline(0, ls="--", color=(0.4, 0.4, 0.4))
        self.axes.minorticks_off()
        # draw 95, these thin lines are the pickable artists
        for p, l, h in zip(self.pos, self.l95, self.h95):
            self.axes.plot([l, h], [p, p], color=self.color,
                           solid_capstyle="round", lw=1, zorder=1, picker=0.5)
        # draw 68 (slope +/- one standard error)
        for p, l, h in zip(self.pos, self.l68, self.h68):
            self.axes.plot(
                [l, h], [p, p], c=self.color, solid_capstyle="round", lw=2.5, zorder=5)
        # draw centers
        self.axes.plot(
            self.centers, self.pos, "o", ms=8, zorder=10, c=self.color)
        # ticks
        self.axes.set_yticks(self.pos)
        self.axes.set_yticklabels(self.names)
        self.axes.set_xlabel("Standardized coefficients")

    def get_tooltip(self, event):
        """
        Build the tooltip for the coefficient closest to the cursor

        Args:
            event : Pick event, only the ``ydata`` of its mouse event is used

        Returns:
            str with name and slope of the coefficient, empty when the cursor is not over one
        """
        y_coord = event.mouseevent.ydata
        if y_coord is None:
            # the cursor is outside of the axes
            return ""
        i = int(round(y_coord))
        if not 0 <= i < self.n_coefs:
            # a negative index would silently wrap around to the last coefficients
            return ""
        # positional access, the series itself is indexed by coefficient name
        return "%s: %.2g" % (self.names[i], self.centers.iloc[i])
class ResidualsDiagnosticPlot(AbstractPlot):
    """
    Creates two plots to analyze distributions of residuals from a regression.

    The first one shows the residuals against the fitted values. This should be used
    to check the hypothesis that the variance must be constant across this range.

    The second one shows a histogram of the residuals. This should be used to verify that the
    residuals distribution is close to normal.

    To create this plot call :meth:`MatplotWidget.draw_residuals`

    Args:
        figure : matplotlib figure, it is cleared and split in two axes sharing the y axis
        residuals : Residual values, drawn in both axes
        fitted : Fitted values, x coordinates of the scatter plot
        names : Optional sequence with the name of each point, used for tooltips
    """

    def __init__(self, figure, residuals, fitted, names=None):
        sns.set_style("darkgrid")
        self.names = names
        figure.clear()
        self.fig = figure
        gs = gridspec.GridSpec(1, 2, width_ratios=(2, 1))
        # right panel: histogram of the residuals
        self.axes = self.fig.add_subplot(gs[1])
        self.axes.clear()
        self.axes.tick_params(
            'x', bottom=True, labelbottom=True, labeltop=False, top=False)
        self.axes.tick_params(
            'y', left=False, labelleft=False, labelright=False, right=False)
        self.axes.yaxis.set_label_position("right")
        self.axes.set_ylim(auto=True)
        self.axes.set_xlabel("Frequency")
        # left panel: residuals against fitted values
        self.axes2 = self.fig.add_subplot(gs[0], sharey=self.axes)
        self.axes2.tick_params(
            'x', bottom=True, labelbottom=True, labeltop=False, top=False)
        self.axes2.tick_params(
            'y', left=True, labelleft=True, labelright=False, right=False)
        self.axes2.set_ylabel("Residuals")
        self.axes2.set_xlabel("Fitted")
        self.axes2.yaxis.set_label_position("left")
        self.axes2.axhline(color='k')
        self.subject_markers = None
        self.residuals = residuals
        self.fitted = fitted
        # artists created by the last redraw, removed before drawing again
        self._drawn_artists = []
        self.redraw()

    def get_tooltip(self, event):
        """
        Build the tooltip for the picked points of the scatter plot

        Args:
            event : Pick event, ``event.ind`` holds the positions of the picked points

        Returns:
            str with one name per line, empty when there are no names or the
            event did not happen in the scatter axes
        """
        if self.names is None:
            return ""
        if event.mouseevent.inaxes != self.axes2:
            return ""
        return "\n".join("%s" % self.names[i] for i in event.ind)

    def redraw(self):
        """
        Draw histogram and scatter again, replacing the ones from the previous call
        """
        # without this every call would stack another histogram and another scatter
        # on top of the old ones, and picked points would be reported several times
        for artist in self._drawn_artists:
            artist.remove()
        residuals, fitted = self.residuals, self.fitted
        _, _, bars = self.axes.hist(
            residuals, color="#2ca25f", bins=20, orientation="horizontal")
        points = self.axes2.scatter(
            fitted, residuals, s=20, color="#2ca25f", picker=0.5)
        self._drawn_artists = list(bars) + [points]
class MessagePlot(AbstractPlot):
    """
    Draws a text message into a :class:`MatplotWidget`

    NOTE(review): the previous docstring pointed to ``MatplotWidget.draw_scatter`` as the
    way to create this plot, which looks like a copy-paste slip - confirm the right method.

    Args:
        axes : matplotlib axes to write into, its limits are set to the unit square
        message (str) : Text shown at the center of the axes
    """

    def __init__(self, axes, message):
        sns.set_style("darkgrid")
        self.axes = axes
        self.axes.set_ylim(0, 1)
        self.axes.set_xlim(0, 1)
        self.message = message
        # booleans, not 'off' strings: current matplotlib treats any non-empty string as True
        self.axes.tick_params(
            'x', bottom=False, labelbottom=False, labeltop=False, top=False)
        self.axes.tick_params(
            'y', left=False, labelleft=False, labelright=False, right=False)
        self.axes.yaxis.set_label_position("right")
        # text artist created by the last redraw
        self._text_artist = None
        self.redraw()

    def redraw(self):
        """
        Write the message at the center of the axes, replacing the previous text
        """
        if self._text_artist is not None:
            # drop the old text, otherwise every redraw paints one more copy over it
            try:
                self._text_artist.remove()
            except (ValueError, NotImplementedError):
                # the axes were cleared by someone else, nothing left to remove
                pass
        self._text_artist = self.axes.text(
            0.5, 0.5, self.message, horizontalalignment='center',
            verticalalignment='center', fontsize=16)
class ScatterPlot(AbstractPlot):
    """
    Draws a scatter plot in :class:`MatplotWidget`.

    The plot may contain

        - a line showing regression results
        - data from different groups painted with different colors

    To create this plot call :meth:`MatplotWidget.draw_scatter`

    Args:
        axes : matplotlib axes to draw into
        data (pandas.DataFrame) : Data to plot, indexed by subject; a copy is kept
        x_var (str) : Column with the x coordinates
        y_var (str) : Column with the y coordinates
        xlabel (str) : Label for the x axis, defaults to *x_var*
        ylabel (str) : Label for the y axis, defaults to *y_var*
        reg_line (bool) : Passed to seaborn as ``fit_reg``
        hue_var (str) : Optional column whose levels are drawn with different colors
        hue_labels (dict) : Optional mapping from level to legend text, the legend is only drawn when given
        qualitative_map (bool) : Use the ``Dark2`` palette for the levels, otherwise ``YlOrRd``
        x_ticks (dict) : Optional mapping from x position to tick label
    """

    def __init__(self, axes, data, x_var, y_var, xlabel=None, ylabel=None, reg_line=True, hue_var=None, hue_labels=None,
                 qualitative_map=True, x_ticks=None):
        sns.set_style("darkgrid")
        self.x_name = x_var
        self.y_name = y_var
        self.z_name = hue_var
        # kept so redraw() can restore them, clearing the axes erases the labels
        self.xlabel = x_var if xlabel is None else xlabel
        self.ylabel = y_var if ylabel is None else ylabel
        self.df = data.copy()
        self.axes = axes
        self.reg_line = reg_line
        # first color of the active palette; rcParams['axes.color_cycle'] is gone in matplotlib 2+
        self.color = sns.color_palette()[0]
        self.hue_labels = hue_labels
        self.qualitative_map = qualitative_map
        self.x_ticks = x_ticks
        self.subject_markers = None
        self.last_id = None
        self.to_highlight = None
        self.redraw()

    def _label_for_level(self, level):
        """Legend text for one hue level, ``"?"`` when it has no label"""
        if self.hue_labels is None:
            return "?"
        try:
            key = int(level)
        except (TypeError, ValueError):
            # non numeric levels are looked up as they are
            key = level
        return self.hue_labels.get(key, "?")

    def redraw(self):
        """
        Clear the axes and draw points, regression lines, legend, ticks and highlights again
        """
        self.axes.clear()
        # clear() detached the highlight markers together with everything else
        self.subject_markers = None
        if self.z_name is None:
            sns.regplot(x=self.x_name, y=self.y_name, data=self.df, fit_reg=self.reg_line,
                        scatter_kws={"picker": 0.5, "url": self.df.index}, ax=self.axes,
                        color=self.color)
        else:
            # fix the limits to the range of the whole data plus a 5% margin on each side
            x_low, x_high = self.df[self.x_name].min(), self.df[self.x_name].max()
            x_margin = (x_high - x_low) / 20.0
            y_low, y_high = self.df[self.y_name].min(), self.df[self.y_name].max()
            y_margin = (y_high - y_low) / 20.0
            self.axes.set_xlim((x_low - x_margin, x_high + x_margin))
            self.axes.set_ylim((y_low - y_margin, y_high + y_margin))
            self.artists_dict = dict()
            # rows without a level can not be matched to any group, skip the missing level
            unique_levels = np.unique(self.df[self.z_name].dropna())
            n_levels = len(unique_levels)
            if self.qualitative_map:
                colors = sns.color_palette("Dark2", n_levels)
            else:
                # first one is too light
                colors = sns.color_palette("YlOrRd", n_levels + 1)[1:]
            for level_color, level in zip(colors, unique_levels):
                level_df = self.df[self.df[self.z_name] == level]
                sns.regplot(x=self.x_name, y=self.y_name, data=level_df, fit_reg=self.reg_line,
                            scatter_kws={"picker": 0.5, "url": level_df.index},
                            label=self._label_for_level(level), ax=self.axes, color=level_color)
            self.add_legend()
        # applied after drawing so neither clear() nor seaborn can replace them
        self.axes.tick_params(
            'x', bottom=True, labelbottom=True, labeltop=False, top=False)
        self.axes.tick_params(
            'y', left=False, labelleft=False, labelright=True, right=True)
        self.axes.yaxis.set_label_position("right")
        self.axes.set_xlabel(self.xlabel)
        self.axes.set_ylabel(self.ylabel)
        log = logging.getLogger(__name__)
        log.info(self.x_ticks)
        if self.x_ticks is not None and len(self.x_ticks) > 0:
            keys, labels = zip(*self.x_ticks.items())
            self.axes.set_xticks(keys)
            self.axes.set_xticklabels(labels)
        if self.to_highlight is not None:
            self.add_subjects(self.to_highlight)

    def add_subjects(self, subjs):
        """
        Draw a ring around the points of the given subjects, replacing previous rings

        Args:
            subjs (list) : Index values of the subjects to mark, unknown ones are ignored
        """
        if self.subject_markers is not None:
            try:
                self.subject_markers.remove()
            except (ValueError, NotImplementedError):
                # the markers were already detached from the axes
                pass
            self.subject_markers = None
        if isinstance(subjs, (str, bytes)) or not hasattr(subjs, "__iter__"):
            subjs = [subjs]
        # a mask marks the subjects that exist, whatever pandas does with missing labels in .loc
        subjs_df = self.df[self.df.index.isin(list(subjs))]
        if len(subjs_df) == 0:
            log = logging.getLogger(__name__)
            log.info("subject %s not found", subjs)
            return
        x_coords = subjs_df[self.x_name]
        y_coords = subjs_df[self.y_name]
        self.subject_markers = self.axes.scatter(x_coords, y_coords, marker="o", s=120, edgecolors="k", alpha=0.80,
                                                 zorder=40, linewidths=2)
        self.subject_markers.set_facecolor('none')

    def highlight(self, subj):
        """
        Select the subject that is marked every time the plot is redrawn

        Args:
            subj : Index value of the subject to highlight
        """
        AbstractPlot.highlight(self, subj)
        self.to_highlight = [subj]

    def get_tooltip(self, event):
        """
        Build the tooltip for the picked points and remember the first one

        Args:
            event : Pick event whose artist carries the subject ids as url

        Returns:
            str with one subject id per line, empty when nothing of this plot was picked
        """
        if event.mouseevent.inaxes != self.axes:
            return ""
        ind = event.ind
        urls = event.artist.get_url()
        if urls is None or len(ind) == 0:
            return ""
        self.last_id = urls[ind[0]]
        return "\n".join("%s" % urls[i] for i in ind)

    def add_legend(self):
        """
        Draw a legend titled with the hue variable, only when hue labels were given
        """
        if self.hue_labels is None:
            return
        self.axes.legend(title=self.z_name)

    def get_last_id(self):
        """
        Returns:
            Id of the last point for which a tooltip was requested, ``None`` before the first one
        """
        return self.last_id
class InterceptPlot(AbstractPlot):
    """
    Draws a plot to show the mean of different data groups

    Every group gets a column of unit width with its points (randomly jittered along x),
    a horizontal line at the group mean and, optionally, a band with a confidence interval.

    To create this plot call :meth:`MatplotWidget.draw_intercept`

    Args:
        axes : matplotlib axes to draw into
        data (pandas.DataFrame) : Data to plot, indexed by subject
        y_var (str) : Column with the values
        groups (str) : Optional column with the group of each row, a single group is used when ``None``
        y_label (str) : Label for the y axis, defaults to *y_var*
        ci_plot (bool) : Draw the confidence interval of each mean
        color : Color for all the artists, defaults to the second color of the active palette
        group_labels (dict) : Optional mapping from group value to tick label
    """

    def __init__(self, axes, data, y_var, groups=None, y_label=None, ci_plot=True, color=None, group_labels=None):
        sns.set_style("darkgrid")
        self.y_name = y_var
        # kept so redraw() can restore it, clearing the axes erases the label
        self.y_label = y_var if y_label is None else y_label
        self.df = data.copy()
        self.axes = axes
        self.ci_plot = ci_plot
        # own copy, columns are renamed and added below
        self.internal_df = self.df[[y_var]].copy()
        self.internal_df.columns = ["y_data"]
        self.subject_markers = None
        self.last_viewed = None
        if color is None:
            # rcParams['axes.color_cycle'] is gone in matplotlib 2+, ask seaborn instead
            self.color = sns.color_palette()[1]
        else:
            self.color = color
        self.x_ticks = None
        if groups is None:
            arrays = [data[y_var]]
            urls = [data.index]
        else:
            g_i = np.unique(data[groups])
            arrays = [data[y_var][data[groups] == g].values for g in g_i]
            urls = [data.index[(data[groups] == g).values] for g in g_i]
            if group_labels is not None:
                self.x_ticks = [group_labels.get(g, "?") for g in g_i]
        # group i lives in [i, i+1], jitter the points inside the central 40% of that column
        self.x_data = [
            i + 0.3 + 0.4 * np.random.random(len(a)) for i, a in enumerate(arrays)]
        if groups is None:
            self.internal_df["x_data"] = self.x_data[0]
        else:
            x_ser = pd.Series(index=self.internal_df.index, dtype=float)
            for xd, ix in zip(self.x_data, urls):
                x_ser.loc[ix] = xd
            self.internal_df["x_data"] = x_ser
        # calculate mean and confidence intervals
        self.ci = [_get_ci(a) for a in arrays] if self.ci_plot else None
        self.means = [np.mean(a) for a in arrays]
        self.arrays = arrays
        self.urls = urls
        self.redraw()

    def get_tooltip(self, event):
        """
        Build the tooltip for the first picked point and remember it

        Args:
            event : Pick event whose artist carries the subject ids as url

        Returns:
            str with the subject id, empty when no point was picked
        """
        urls = event.artist.get_url()
        ind = event.ind
        if urls is None or len(ind) == 0:
            return ""
        subj = urls[ind[0]]
        self.last_viewed = subj
        return "%s" % subj

    def redraw(self):
        """
        Clear the axes and draw points, means and confidence bands again
        """
        self.axes.clear()
        # clear() detached the subject markers together with everything else
        self.subject_markers = None
        ax = self.axes
        c = self.color
        ax.tick_params(
            'x', bottom=False, labelbottom=False, labeltop=False, top=False)
        ax.tick_params(
            'y', left=False, labelleft=False, labelright=True, right=True)
        ax.yaxis.set_label_position("right")
        ax.set_ylabel(self.y_label)
        ax.set_xlabel("")
        columns = zip(self.arrays, self.x_data, self.means, self.urls)
        for i, (y_data, x_data, m, url) in enumerate(columns):
            # scatter
            ax.scatter(
                x_data, y_data, color=c, url=url, picker=0.5, zorder=10, alpha=0.9)
            # line
            ax.hlines(m, i, i + 1, colors=c, zorder=5)
            # error, the intervals only exist when they were requested
            if self.ci_plot:
                l, h = self.ci[i]
                ax.fill_between(
                    (i, i + 1), (l, l), (h, h), zorder=4, color=c, alpha=0.2)
        ax.set_xlim(0, len(self.arrays), auto=False)
        if self.x_ticks is not None:
            ax.set_xticks(np.arange(0.5, len(self.arrays), 1))
            ax.set_xticklabels(self.x_ticks)
            ax.tick_params(
                'x', bottom=True, labelbottom=True, labeltop=False, top=False)

    def add_subjects(self, subjs):
        """
        Draw a ring around the points of the given subjects, replacing previous rings

        Args:
            subjs (list) : Index values of the subjects to mark, unknown ones are ignored
        """
        if self.subject_markers is not None:
            try:
                self.subject_markers.remove()
            except (ValueError, NotImplementedError):
                # the markers were already detached from the axes
                pass
            self.subject_markers = None
        if isinstance(subjs, (str, bytes)) or not hasattr(subjs, "__iter__"):
            subjs = [subjs]
        subjs_df = self.internal_df[self.internal_df.index.isin(list(subjs))]
        if len(subjs_df) == 0:
            log = logging.getLogger(__name__)
            log.info("subject %s not found", subjs)
            return
        x_coords = subjs_df["x_data"]
        y_coords = subjs_df["y_data"]
        self.subject_markers = self.axes.scatter(x_coords, y_coords, marker="o", s=120, edgecolors="k", alpha=0.80,
                                                 zorder=40, linewidths=2)
        self.subject_markers.set_facecolor('none')

    def get_last_id(self):
        """
        Returns:
            Id of the last point for which a tooltip was requested, ``None`` before the first one
        """
        return self.last_viewed
def _get_ci(array):
    """
    Bootstrap the mean of *array* and summarize the resamples as an interval

    Args:
        array : Sample values, forwarded to ``sns.algo.bootstrap``

    Returns:
        tuple built from the result of ``sns.utils.ci`` on the bootstrapped means;
        presumably the (low, high) bounds of seaborn's default interval - confirm
    """
    resampled_means = sns.algo.bootstrap(array, func=np.mean)
    return tuple(sns.utils.ci(resampled_means))
if __name__ == "__main__":
    # Manual demo: show a horizontal, grouped bar plot with random data.
    # NOTE(review): MatplotWidget is not defined in the visible part of this module - confirm
    # it is available at this point.

    # the Qt application has to exist before the widget is created
    app = QtGui.QApplication([])

    # ten random values, each one assigned to group 1 or 2
    values = np.random.rand(10)
    groups = np.random.randint(1, 3, 10)
    columns = ["test", "group"]
    data = pd.DataFrame(dict(zip(columns, (values, groups))), columns=columns)

    widget = MatplotWidget()
    widget.show()
    labels = {1: "One", 2: "Two"}
    # pass orientation="vertical" and no labels to try the ungrouped legend-free variant
    widget.draw_bars(data, orientation="horizontal", group_labels=labels)
    app.exec_()