Search
Fig 2. Comparison of RNN and Bayesian decoders

This code example is a Jupyter notebook that uses the Script of Scripts (SoS) workflow.

The calculations are written in Python 2.7 (taken from the authors' repository), and the interactive figures are written in Python 3.6 with Plotly.

Figure 2:

(a) Histogram of error sizes, generated in each case with the best performing time window (1400 ms for RNN, 2800 ms for flat Bayesian and 2000 ms for Bayesian with memory). Both types of Bayesian decoders make more very large errors than the RNN (0.02% vs 2.7% of errors > 50 cm). Errors are grouped into 2 cm bins, the last bin shows all errors above 50 cm. (b-c) Downsampling analysis demonstrates the RNN decoder is more robust to small dataset sizes. Data from R2192 were downsampled and all three decoders were trained with a random subset of the available neurons. For each sample size, 10 random sets of neurons were selected and independent models trained as before using 10-fold cross-validation. Dots represent (b) mean and (c) median error for each downsampled dataset. Lines indicate the (b) mean of means and (c) mean of medians over sets of the same size.

https://doi.org/10.1371/journal.pcbi.1006822.g002

%use Python2
import numpy as np
import matplotlib.pyplot as plt
import pickle 
%matplotlib inline
from scipy.stats import binned_statistic
import math

import scipy.io as scio

import matplotlib.style
import matplotlib as mpl
mpl.style.use('classic')


# The "Tableau 20" palette, listed as 8-bit (0-255) RGB triples.
tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]

# matplotlib expects each channel as a float in [0, 1]; rescale every triple
# and write the result back into the same list object.
tableau20[:] = [tuple(channel / 255. for channel in rgb) for rgb in tableau20]
    
def mse(y, t, axis=2):
    """Return the mean squared difference between *y* and *t*.

    Squared differences are first averaged along ``axis`` and the resulting
    array is then averaged over all of its remaining entries.
    """
    per_entry = np.mean(np.square(y - t), axis=axis)
    return np.mean(per_entry)

def mean_distance(y, t, axis=2):
    """Return the average Euclidean distance between *y* and *t*.

    The distance is computed along ``axis`` (the coordinate axis) and then
    averaged over every remaining entry.
    """
    diff = y - t
    distances = np.sqrt((diff ** 2).sum(axis=axis))
    return np.mean(distances)

def median_distance(y, t, axis=2):
    """Return the median Euclidean distance between *y* and *t*.

    The distance is computed along ``axis`` (the coordinate axis); the median
    is then taken over every remaining entry.
    """
    diff = y - t
    distances = np.sqrt((diff ** 2).sum(axis=axis))
    return np.median(distances)

# Number of cross-validation runs whose RNN predictions are loaded below
# (files ..._v1.npz ... ..._v<cvfolds>.npz).
cvfolds = 10
data = np.load("R2192_models/window_scan_R2192_1x1400_predictions_v1.npz")

# Predictions and targets of the first CV run; `targets` is the reference
# every other run is checked against.
pred_y = data["preds"]
y = data["targets"]
targets = data["targets"]

seqlen = pred_y.shape[1]
pred_y = pred_y.reshape((-1, 2))

# Euclidean error of the first run. NOTE: `errors` is reassigned later in this
# cell (flat-Bayes errors), so this value is only for interactive inspection.
errors = np.sqrt(np.sum((pred_y - y)**2, axis=1))

# WHAT IS CHANCE LEVEL ACCURACY? (shuffled targets; not used further here)
r_y = y.copy()
np.random.shuffle(r_y)


all_preds = []
all_errors = []

# Load the predictions of every CV run from R2192_models.
for i in range(1, cvfolds + 1):
    data = np.load("R2192_models/window_scan_R2192_1x1400_predictions_v%d.npz" % i)
    pr_i = data["preds"]
    tgts = data["targets"]
    # All runs must have been evaluated on identical targets. An explicit
    # check (rather than `assert`) survives `python -O`, and array_equal also
    # reports a shape mismatch cleanly.
    if not np.array_equal(tgts, targets):
        raise ValueError("targets of CV run %d differ from those of run 1" % i)

    all_preds.append(pr_i)
    # Errors of each model individually, not the error of the ensemble model.
    errors_i = np.sqrt(np.sum((pr_i - y)**2, axis=1))
    all_errors.append(errors_i)


avg_pred = np.mean(all_preds, axis=0)  # prediction of the ensemble model
# Per-location average of the individual models' errors -- NOT the ensemble error.
avg_errors = np.mean(all_errors, axis=0)
median_errors = np.median(all_errors, axis=0)

# Errors of the ensemble model (not part of the article).
ensemble_errors = np.sqrt(np.sum((avg_pred - y)**2, axis=1))
# NOTE(review): the output name says "R2335" although these are R2192
# predictions; left unchanged in case other code reads this path -- confirm.
np.save("predictions_by_ensemble_of_10_R2335models.npy", avg_pred)

### READ IN ERRORS MADE BY BAYES WITH MEMORY
# Per-sample errors of the Bayesian decoder with memory, taken directly from
# the MATLAB results struct. (An older file, BayesMemory_R2192_2200ms.mat,
# used the wrong history size and is no longer read.)
bayes_mem_preds = scio.loadmat("Bayes_results/Fig2a_2d2-2TimeWindowFullErrorWithBayesMem.mat")
b_mem_errors = bayes_mem_preds["animalStruct"][0][0][1][0][0][3]

# READ IN TRUE AND PREDICTED LOCATIONS WITH FLAT BAYES
# Per the original notes, entry 13 of the window scan is the 2800 ms window;
# its item 1 holds the true and item 2 the predicted locations.
x = scio.loadmat("Bayes_results/Fig2a_2dDecode2cmBins.mat")
flat_bayes_2800ms = x["animalStruct"][0][0][1][0][13]
true_loc, pred_loc = flat_bayes_2800ms[1], flat_bayes_2800ms[2]

# Euclidean error per sample. This rebinds `errors` (previously the RNN
# first-run errors) to the flat-Bayes errors.
# The same locations are also stored as text files:
# Bayes_results/R2192_2800ms_flatBayes_{predicted,true}_loc.txt
delta = true_loc - pred_loc
errors = np.sqrt((delta ** 2).sum(axis=1))
errors_clipped = np.clip(errors, 0, 50)

##########################################################################################################

# Downsampling results (Fig 2b-c) of the flat-prior Bayesian decoder.
bayes_flat_downsample = scio.loadmat("Bayes_results/Fig2bc_downSampResultsCompl.mat")
bayes_flat_downsample_means = bayes_flat_downsample["allMeanErr"]
bayes_flat_downsample_medians = bayes_flat_downsample["allMedianErr"]

# Downsampling results of the Bayesian decoder with memory.
bayes_mem_downsample = scio.loadmat("Bayes_results/Fig2bc_downSampResultsWithBayesHistoryTWin1-2.mat")
bayes_mem_downsample_means = bayes_mem_downsample["allMeanErr"]
bayes_mem_downsample_medians = bayes_mem_downsample["allMedianErr"]

# Neuron counts of the downsampling analysis and the number of random neuron
# sets drawn for each count.
X = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55]
n_sets_per_size = 10

# One entry per downsampled model, in file order: 10 x 5, 10 x 10, ..., 10 x 55.
sample_sizes = []
for size in X:
    sample_sizes.extend([size] * n_sets_per_size)

# RNN downsampling results. Column -5 is used as the mean and column -1 as the
# median error (see mean_of_means / mean_of_medians below). Each token carries
# one trailing character that is stripped before conversion. Parsing straight
# to float avoids the previous round-trip through the fixed-width string array
# (which could round/truncate values) and the Python-2-only list-returning map.
rnn_rows = np.loadtxt("R2192_models/downsampling_results/means_and_medians.txt",
                      dtype=str, ndmin=2)
mea_and_med = np.array([[float(token[:-1]) for token in row]
                        for row in rnn_rows[:, [-5, -1]]], dtype=float)

bay_flat_means = bayes_flat_downsample_means
bay_mem_means = bayes_mem_downsample_means
bay_flat_medians = bayes_flat_downsample_medians
bay_mem_medians = bayes_mem_downsample_medians

# Rows are grouped by neuron count (n_sets_per_size consecutive rows each),
# so reshaping to (len(X), n_sets_per_size) and averaging over axis 1 gives
# one value per neuron count.
mean_of_means = np.mean(np.reshape(mea_and_med[:, 0], (len(X), n_sets_per_size)), axis=1)
mean_of_medians = np.mean(np.reshape(mea_and_med[:, 1], (len(X), n_sets_per_size)), axis=1)
Y = np.mean(bay_mem_medians, axis=1)

# Hand the results to the Python 3 plotting cells; they unpack exactly these
# 13 items in this order.
with open('train.pickle', 'wb') as f:
    pickle.dump([tableau20, sample_sizes, bay_flat_means, bay_mem_means, mean_of_means, mea_and_med, bayes_flat_downsample_means, bayes_mem_downsample_means,
                 bayes_flat_downsample_medians, bayes_mem_downsample_medians, errors, avg_errors, b_mem_errors], f)

(a) Distribution of errors with each method

%use Python3
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
import numpy as np
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}

with open('train.pickle', 'rb') as f:
    tableau20, sample_sizes, bay_flat_means, bay_mem_means, mean_of_means, mea_and_med, bayes_flat_downsample_means, bayes_mem_downsample_means, bayes_flat_downsample_medians, bayes_mem_downsample_medians, errors, avg_errors, b_mem_errors = pickle.load(f, encoding='bytes')


figa = go.Figure()

figa.add_trace(go.Histogram(x=np.clip(avg_errors,0,52),
                            histnorm='probability',
                            name='Errors by RNN decoder',
                            showlegend = True, 
                            xbins=dict(
                                start=0.0,
                                end=52.1,
                                size=2.0),
                            marker_color="rgb"+str(tableau20[0]),
                            opacity=0.75,
                             hovertemplate = '<b>x: </b> <i> %{x} cm </i> <br><b>y: </b> <i>%{y}</i>',
                            marker=dict(line=dict(width=0.4,
                                                  color='black'))))

figa.add_trace(go.Histogram(x=np.clip(errors,0,52),
                            histnorm='probability',
                            name='Errors by Bayesian with flat prior (MLE)',
                             hovertemplate = '<b>x: </b> <i> %{x} cm </i> <br><b>y: </b> <i>%{y}</i>',
                            showlegend = True, 
                            xbins=dict(
                                start=0.0,
                                end=52.1,
                                size=2.0),
                            marker_color="rgb"+str(tableau20[13]),
                            opacity=0.75, 
                            alignmentgroup = "Errors by RNN decoder", 
                            marker=dict(line=dict(width=0.4,
                                                  color='black'))))

figa.add_trace(go.Histogram(x=np.clip(b_mem_errors,0,52).flatten(),
                            histnorm='probability',
                            name='Errors by Bayesian with memory',
                             hovertemplate = '<b>x: </b> <i> %{x} cm </i> <br><b>y: </b> <i>%{y}</i>',
                            showlegend = True, 
                            xbins=dict(
                                start=0.0,
                                end=52.1,
                                size=2.0),
                            marker_color="rgb"+str(tableau20[6]),
                            opacity=0.75, 
                            alignmentgroup = "Errors by RNN decoder",
                            marker=dict(line=dict(width=0.4,
                                                  color='black'))))

figa.update_layout(title = '(a)   Distribution of errors with each method', 
                   title_x = 0.5, 
                   xaxis_title='Error (cm)',
                   xaxis=dict(range=[-0.1,52], 
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              tickvals = [0,10,20,30,40,50],
                              linecolor='#000', 
                              tickfont = dict(size=16)),
                   yaxis_title='Proportion of errors',
                   yaxis=dict(range=[0,0.16],
                              mirror=True,
                              ticks='outside', 
                              showline=True,
                              tickvals = [0.0,0.02,0.04,0.06,0.08,0.10,0.12,0.14], 
                              linecolor='#000',
                              tickfont = dict(size=16)),
                   legend=dict(x=0.22, 
                              y=0.95,
                              bordercolor="Gray",
                              borderwidth=1),
                  plot_bgcolor='#fff', 
                  width = 550, 
                  height = 400,
                  font = dict(size = 13),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))
figa.update_traces()

plot(figa, filename = 'fig2_a.html', config = config)
# THEBELAB
display(HTML('fig2_a.html'))
# BINDER
# iplot(figa,config=config)

(b) Mean results with downsampled datasets

%use Python3
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}

with open('train.pickle', 'rb') as f:
    tableau20, sample_sizes, bay_flat_means, bay_mem_means, mean_of_means, mea_and_med, bayes_flat_downsample_means, bayes_mem_downsample_means, bayes_flat_downsample_medians, bayes_mem_downsample_medians, errors, avg_errors, b_mem_errors = pickle.load(f, encoding='bytes')


figb = go.Figure()

# 1st line 
figb.add_trace(go.Scatter(x = [5,10,15,20,25,30,35,40,45,50,55], 
                         y = np.mean(bay_flat_means,axis=1), 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[13]),
                                   width=4),
                         name = "Bayesian with flat prior (MLE)",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>'))

figb.add_trace(go.Scatter(x = sample_sizes, 
                         y = bay_flat_means.flatten(), 
                         mode = 'markers',
                         showlegend=False, 
                         line=dict(color="rgb"+str(tableau20[13])),
                         name = "Bayesian with flat prior (MLE) flatten",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>',  
                         marker=dict(size=6,
                              line=dict(width=0.4,
                                        color='black'))))

# 2nd line 
figb.add_trace(go.Scatter(x = [5,10,15,20,25,30,35,40,45,50,55], 
                         y = np.mean(bay_mem_means,axis=1), 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[6]),
                                   width=4),
                         name = "Bayesian decoder with memory",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>'))

figb.add_trace(go.Scatter(x = sample_sizes, 
                         y = bay_mem_means.flatten(), 
                         mode = 'markers',
                         showlegend=False, 
                         line=dict(color="rgb"+str(tableau20[6])),
                         name = "Bayesian decoder with memory flatten",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>',  
                         marker=dict(size=6,
                              line=dict(width=0.4,
                                        color='black'))))
# 3rd line 
figb.add_trace(go.Scatter(x = [5,10,15,20,25,30,35,40,45,50,55], 
                         y = mean_of_means,
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[0]),
                                   width=4),
                         name = "RNN model",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>'))

figb.add_trace(go.Scatter(x = sample_sizes, 
                         y = mea_and_med[:,0], 
                         mode = 'markers',
                         showlegend=False, 
                         line=dict(color="rgb"+str(tableau20[1])),
                         name = "RNN model flatten",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>',
                         marker=dict(size=6,
                              line=dict(width=0.4,
                                        color='black'))))

figb.update_layout(title = '(b)  Mean results with downsampled datasets',
                  title_x = 0.5, 
                  xaxis_title='Number of sampled neurons',
                  hovermode='closest',
                  xaxis=dict(range=[4,57], 
                             mirror=True,
                             ticks='outside',
                             showline=True,
                             tickvals = [5,10,15,20,25,30,35,40,45,50,55],
                             linecolor='#000', 
                             tickfont = dict(size=16)), 
                  yaxis_title='Mean error (cm)',
                  yaxis=dict(range=[9,50], 
                             mirror=True,
                             ticks='outside', 
                             showline=True,
                             tickvals = [0,5,10,15,20,25,30,35,40,45],
                             linecolor='#000',
                             tickfont = dict(size=16)),
                  legend=dict(x=0.37, 
                              y=0.85,
                              bordercolor="Gray",
                              borderwidth=1),
                  plot_bgcolor='#fff', 
                  width = 550, 
                  height = 400,
                  font = dict(size = 13),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))

figb.update_traces()

plot(figb, filename = 'fig2_b.html', config = config)
# THEBELAB
display(HTML('fig2_b.html'))
# BINDER
# iplot(figb,config=config)

(c) Median results with downsampled datasets

%use Python3
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}


with open('train.pickle', 'rb') as f:
    tableau20, sample_sizes, bay_flat_means, bay_mem_means, mean_of_means, mea_and_med, bayes_flat_downsample_means, bayes_mem_downsample_means, bayes_flat_downsample_medians, bayes_mem_downsample_medians, errors, avg_errors, b_mem_errors = pickle.load(f, encoding='bytes') 

bay_flat_medians = bayes_flat_downsample_medians
bay_mem_medians = bayes_mem_downsample_medians    
mean_of_medians = np.mean(np.reshape(mea_and_med[:,1],(11,10)),axis=1)

figc = go.Figure()

# 1st line 
figc.add_trace(go.Scatter(x = [5,10,15,20,25,30,35,40,45,50,55], 
                         y = np.mean(bay_flat_medians,axis=1), 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[13]),
                                   width=4),
                         name = "Bayesian with flat prior (MLE)",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>'
                         ))

figc.add_trace(go.Scatter(x = sample_sizes, 
                         y = bay_flat_medians.flatten(), 
                         mode = 'markers',
                         showlegend=False, 
                         line=dict(color="rgb"+str(tableau20[13])),
                         name = "Bayesian with flat prior (MLE) flatten", 
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>',
                          marker=dict(size=6,
                              line=dict(width=0.4,
                                        color='black'))))

# 2nd line 
figc.add_trace(go.Scatter(x = [5,10,15,20,25,30,35,40,45,50,55], 
                         y = np.mean(bay_mem_medians,axis=1), 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[6]),
                                   width=4),
                         name = "Bayesian decoder with memory",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>'))

figc.add_trace(go.Scatter(x = sample_sizes, 
                         y = bay_mem_medians.flatten(), 
                         mode = 'markers',
                         showlegend=False, 
                         line=dict(color="rgb"+str(tableau20[6])),
                         name = "Bayesian decoder with memory flatten",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>', 
                         marker=dict(size=6,
                              line=dict(width=0.4,
                                        color='black'))))
# 3rd line 
figc.add_trace(go.Scatter(x = [5,10,15,20,25,30,35,40,45,50,55], 
                         y = mean_of_medians,
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[0]),
                                   width=4),
                         name = "RNN model",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>'))

figc.add_trace(go.Scatter(x = sample_sizes, 
                         y = mea_and_med[:,0], 
                         mode = 'markers',
                         showlegend=False, 
                         line=dict(color="rgb"+str(tableau20[1])),
                         name = "RNN model flatten",
                         hovertemplate = '<b>x: </b> <i> %{x} sampled neurons</i> <br><b>y: </b> <i>%{y} cm </i>',
                         marker=dict(size=6,
                              line=dict(width=0.4,
                                        color='black'))))

figc.update_layout(title = '(c)  Median results with downsampled datasets',
                  title_x = 0.5, 
                  xaxis_title='Number of sampled neurons',
                  hovermode='closest',
                  xaxis=dict(range=[4,57], 
                             mirror=True,
                             ticks='outside',
                             showline=True,
                             tickvals = [5,10,15,20,25,30,35,40,45,50,55],
                             linecolor='#000', 
                             tickfont = dict(size=16)), 
                  yaxis_title='Mean error (cm)',
                  yaxis=dict(range=[9,50], 
                             mirror=True,
                             ticks='outside', 
                             showline=True,
                             tickvals = [0,5,10,15,20,25,30,35,40,45],
                             linecolor='#000',
                             tickfont = dict(size=16)),
                  legend=dict(x=0.37, 
                              y=0.85,
                              bordercolor="Gray",
                              borderwidth=1),
                  plot_bgcolor='#fff', 
                  width = 550, 
                  height = 400,
                  font = dict(size = 13),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))

figc.update_traces()

plot(figc, filename = 'fig2_c.html', config = config)
# THEBELAB
display(HTML('fig2_c.html'))
# BINDER
# iplot(figc,config=config)