Search
Fig 1. Accurate decoding of position with an RNN

This code example is a Jupyter notebook that uses a Script of Scripts (SoS) workflow.

The calculations are written in Python 2.7 (taken from the authors' repository), and the interactive figures are written in Python 3.6 with Plotly.

Figure 1:

Location decoding errors, as a function of time window size, based on CA1 neural data recorded in a 1 m square open-field environment. (a) shows the mean error and (b) the median error. Blue lines represent errors from the RNN decoder and red lines errors from the Bayesian approaches. Results for the RNN approach are averaged over different independent realizations of the training algorithm. Black dots depict the mean/median error of each individual model. Results shown are for animal R2192.

https://doi.org/10.1371/journal.pcbi.1006822.g001

%use Python2 
import matplotlib.style
import matplotlib as mpl
mpl.style.use('classic')
import numpy as np
from scipy.io import loadmat
import pickle


#set up nicer color scheme
tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),  
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),  
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),  
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),  
(188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]
# Rescale to values between 0 and 1 
for i in range(len(tableau20)):  
    r, g, b = tableau20[i]  
    tableau20[i] = (r / 255., g / 255., b / 255.)
    
# this file contains the perfromance of 10 models on R2192 "grep"-ed from the actual log files
f=open("R2192_grepped_predictions.log")
lines = f.readlines()
f.close()

# fill a dictionary where keys are timewindow sizes, filled with [mean, median] for 10 models
RNN_stats={}
for line in lines:
    pieces=line.split(" ")
    win_size = pieces[0].split("x")[1]
    win_size = int(win_size[:win_size.find("_")])
    median = float(pieces[-1])
    mean = float(pieces[-5][:-1])
    if win_size in RNN_stats.keys():
        RNN_stats[win_size].append([mean, median])
    else:
        RNN_stats[win_size]=[[mean, median]]
# print RNN_stats[1400]

# Results with Bayes with flat prior (MLE)
# imported as a dictionary, each item in dictionary contains results for all 5 rats
# first rat is R2192 (index 0)
bay_field_dict = loadmat("Bayes_res/Fig1&3ab_decodingDataForOpenField.mat")
#bay_field_dict = loadmat("Bayes_res/2dDecodeAllRatsAllWindowsNaiveBayesNoTruncate.mat")
# print bay_field_dict.keys()
flat_medians=bay_field_dict['medianErr']
flat_means= bay_field_dict['meanErr']
flat_win = bay_field_dict['tWin2Test'].flatten()
# print flat_win.shape, flat_means[:,0].shape
# print "\n for Table 1: ", bay_field_dict["bstMean"]

# """ for Table 1:  [[ 15.82963073  16.06929415  17.86089428  18.81734775  17.03691594]
#  [  2.8          3.8          2.8          2.8          3.4       ]]"""

# Results with Bayes with flat prior (MLE)
# imported as a dictionary, each item in dictionary contains results for all 5 rats
# first rat is R2192 (index 0)

# bay_field_dict_history_h5 = loadmat("Bayes_res/oldParams_2dDecodeFullBayesWithHistorySigma1History5.mat") #old params
bay_field_dict_history = loadmat("Bayes_res/Fig1&3ab_2dDecodeFullBayesWithHistorySigma1History15.mat")
#bay_field_dict_history = loadmat("Bayes_res/2dDecodeAllRatsAllWindowsBayesWithHistorySpeed1History15NoTruncate.mat")
# print bay_field_dict_history.keys()
memory_medians = bay_field_dict_history['medianErr']
memory_means = bay_field_dict_history['meanErr']
memory_win = bay_field_dict_history['tWin2Test'].flatten()

# print memory_means[:,0]
# print memory_medians[:,0]
# print np.min(memory_medians, axis=0)

# print "\n for Table 1: ", bay_field_dict_history["bstMean"]
# """ for Table 1:  [[ 15.46168191  14.99576142  16.5269506   18.26098815  16.40828295]
#  [  2.           1.8          2.8          2.6          3.4       ]]"""

average = []
# draw individual model's performance as dots
for size in sorted(RNN_stats.keys()):
        means = np.array(RNN_stats[size])[:,0]
        average.append(np.mean(means))

# print average, "\n",flat_means[:,0]
a = np.array(memory_medians)

with open('train.pickle', 'wb') as f:
    pickle.dump([size, means, RNN_stats, flat_win, flat_means, flat_medians, memory_medians, tableau20, memory_win, memory_means, average], f)

(a) Mean prediction error for R2192 as a function of time window size

%use Python3
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
init_notebook_mode(connected = True)
import numpy as np
config={'showLink': False, 'displayModeBar': False}
 
with open('train.pickle', 'rb') as f:
    size, means, RNN_stats, flat_win, flat_means, flat_medians, memory_medians, tableau20, memory_win, memory_means, average = pickle.load(f, encoding='bytes')
    
    
figa = go.Figure()

average = []
# draw individual model's performance as dots
for size in sorted(RNN_stats.keys()):
        means = np.array(RNN_stats[size])[:,0]
        figa.add_trace(go.Scatter(x = [size]*len(means), 
                                 y = means,
                                 mode = 'markers',
                                 showlegend=False, 
                                 name = "RNN stats",
                                 line = dict(color="black")))
        average.append(np.mean(means))

figa.add_trace(go.Scatter(x = flat_win*1000, 
                         y = flat_means[:,0], 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[6]),
                                   width=4,
                                   dash="dash"),
                         name = "Bayesian with flat prior (MLE)"))

figa.add_trace(go.Scatter(x = memory_win*1000, 
                         y = memory_means[:,0], 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[6]),
                                   width=4),
                         name = "Bayesian decoder with memory"))

figa.add_trace(go.Scatter(x = sorted(RNN_stats.keys()), 
                         y = average,
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[0]),
                                   width=4),
                         name = "RNN decoder"))


figa.update_layout(title = '(a) Mean errors with different window size',
                  title_x = 0.5, 
                  xaxis_title='Time window size (ms)',
                  xaxis=dict(range=[175,4025], 
                             mirror=True,
                             ticks='outside',
                             showline=True,
                             linecolor='#000',
                             tickfont = dict(size=16)),
                  yaxis_title='Mean error (cm)',
                  yaxis=dict(range=[9,27], 
                             mirror=True,
                             ticks='outside', 
                             showline=True, 
                             linecolor='#000',
                             tickfont = dict(size=16)),
                  legend=dict(x=0.37, 
                              y=0.85,
                              bordercolor="Gray",
                              borderwidth=1),
                  plot_bgcolor='#fff', 
                  width = 550, 
                  height = 400,
                  font = dict(size = 13),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))

plot(figa, filename = 'fig1_a.html', config = config)
# THEBELAB
display(HTML('fig1_a.html'))
# BINDER
# iplot(figa,config=config)

(b) Median prediction error for R2192 as a function of time window size

%use Python3
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
import numpy as np
init_notebook_mode(connected = True)
config={'showLink': False, 'displayModeBar': False}

with open('train.pickle', 'rb') as f:
    size, means, RNN_stats, flat_win, flat_means, flat_medians, memory_medians, tableau20, memory_win, memory_means, average = pickle.load(f, encoding='bytes')


figb = go.Figure()

average = []
# draw individual model's performance as dots
for size in sorted(RNN_stats.keys()):
        medians = np.array(RNN_stats[size])[:,1]
        figb.add_trace(go.Scatter(x = [size]*len(medians), 
                                 y = medians,
                                 mode = 'markers',
                                 showlegend=False,
                                 name = "RNN stats",
                                 line = dict(color="black")))
        average.append(np.mean(medians))

figb.add_trace(go.Scatter(x = flat_win*1000, 
                         y = flat_medians[:,0], 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[6]),
                                   width=4,
                                   dash="dash"),
                         name = "Bayesian with flat piror (MLE)"))

figb.add_trace(go.Scatter(x = memory_win*1000, 
                         y = memory_medians[:,0], 
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[6]),
                                   width=4),
                         name = "Bayesian decoder with memory"))

figb.add_trace(go.Scatter(x = sorted(RNN_stats.keys()), 
                         y = average,
                         mode = 'lines',
                         line=dict(color="rgb"+str(tableau20[0]),
                                   width=4),
                         name = "RNN decoder"))


figb.update_layout(title = '(b) Median errors with different window size',
                  title_x = 0.5, 
                  xaxis_title='Time window size (ms)',
                  xaxis=dict(range=[175,4025], 
                             mirror=True,
                             ticks='outside',
                             showline=True,
                             linecolor='#000',       
                             tickfont = dict(size=16)),
                  yaxis_title='Mean error (cm)',
                  yaxis=dict(range=[9,27], 
                             mirror=True,
                             ticks='outside', 
                             showline=True, 
                             linecolor='#000',
                             tickfont = dict(size=16)),
                  legend=dict(x=0.37, 
                              y=0.85,
                              bordercolor="Gray",
                              borderwidth=1),
                  plot_bgcolor='#fff', 
                  width = 550, 
                  height = 400,
                  font = dict(size = 13),
                  margin=go.layout.Margin(l=50,
                                         r=50,
                                         b=60,
                                         t=35))

plot(figb, filename = 'fig1_b.html', config = config)
# THEBELAB
display(HTML('fig1_b.html'))
# BINDER
# iplot(figb,config=config)