Search
Fig 3. Spatial decoding across animals in 2D and 1D environments

This code example is a Jupyter notebook with a Script of Scripts (SoS) workflow.

The calculations are written in Python 2.7 (taken from the authors' repository), and the interactive figures are produced in Python 3.6 with Plotly.

Figure 3:

(a-b) Decoding results in a 1 m square environment. The RNN consistently outperforms the two Bayesian approaches in all 5 data sets. (a) Mean and (b) median errors across cross-validation folds. (c-d) Decoding errors on a 600 cm long Z-shaped track. The RNN consistently yields lower decoding errors than the Bayesian approaches; the difference is more marked when mean (c) as opposed to median (d) errors are considered.

https://doi.org/10.1371/journal.pcbi.1006822.g003

%use Python2

import matplotlib.pyplot as plt
import matplotlib.style
import matplotlib as mpl
mpl.style.use('classic')
import numpy as np
from scipy.io import loadmat
%matplotlib inline
import pickle

# Set up a nicer colour scheme: the "Tableau 20" palette, written as 0-255 RGB
# triples and rescaled to the 0-1 floats that matplotlib expects.
tableau20 = [
    tuple(channel / 255. for channel in rgb)
    for rgb in [
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
    ]
]
    
# Per-model RNN performance on animal R2192, "grep"-ed from the actual training
# log files (10 models according to the original note).
# `with` guarantees the file is closed even if reading fails.
with open("R2192_grepped_predictions.log") as f:
    lines = f.readlines()

# Map time-window size -> list of [mean, median] pairs, one pair per model.
# NOTE(review): the parsing assumes each line's first space-separated field
# contains "x<window>_...", the last field is the median error and the
# fifth-from-last field is the mean error followed by one trailing character
# (presumably a comma) -- confirm against the log format.
RNN_stats = {}
for line in lines:
    if not line.strip():
        # Skip blank lines (e.g. a trailing newline) instead of crashing on them.
        continue
    pieces = line.split(" ")
    win_size = pieces[0].split("x")[1]
    win_size = int(win_size[:win_size.find("_")])
    median = float(pieces[-1])
    mean = float(pieces[-5][:-1])
    RNN_stats.setdefault(win_size, []).append([mean, median])

# 2D open-field task, Bayesian decoder with a flat prior (maximum-likelihood
# estimate, labelled "MLE" in the plots).
# The .mat file holds results for all 5 rats; the first rat is R2192 (index 0).
bay_field_dict = loadmat("Bayes_res/Fig1&3ab_decodingDataForOpenField.mat")
# Alternative result file:
# bay_field_dict = loadmat("Bayes_res/2dDecodeAllRatsAllWindowsNaiveBayesNoTruncate.mat")
flat_medians = bay_field_dict['medianErr']
flat_means = bay_field_dict['meanErr']
flat_win = bay_field_dict['tWin2Test'].flatten()  # presumably the tested time-window sizes

# 2D open-field task, Bayesian decoder *with history* (labelled "Bayesian
# memory" in the plots) -- NOT the flat-prior decoder above.
# Same layout: results for all 5 rats, the first rat is R2192 (index 0).
# Old parameter set, kept for reference:
# bay_field_dict_history_h5 = loadmat("Bayes_res/oldParams_2dDecodeFullBayesWithHistorySigma1History5.mat")
bay_field_dict_history = loadmat("Bayes_res/Fig1&3ab_2dDecodeFullBayesWithHistorySigma1History15.mat")
# Alternative result file:
# bay_field_dict_history = loadmat("Bayes_res/2dDecodeAllRatsAllWindowsBayesWithHistorySpeed1History15NoTruncate.mat")
memory_medians = bay_field_dict_history['medianErr']
memory_means = bay_field_dict_history['meanErr']
memory_win = bay_field_dict_history['tWin2Test'].flatten()

# Animals in plotting order. R2217 has the fewest neurons, so it is moved to
# the end; order_of_animals[i] is the column of the .mat result arrays that is
# plotted at position i.
labels = ["R2192", "R2198", "R2336", "R2337", "R2217"]
order_of_animals = [0, 1, 3, 4, 2]

# RNN results extracted from the log files and averaged (plotting order).
means = [12.50484, 13.27721, 16.28584545, 14.36655, 14.4936]
mean_std = [0.3846042705, 0.3226775496, 0.269230256, 0.2945978247, 0.2544971159]
medians = [10.3296, 10.75701, 13.09772, 11.3265, 11.73398]
median_std = [0.23188, 0.31606, 0.19988, 0.15216, 0.24249]

# Bayesian baselines: per animal, keep the smallest error along axis 0
# (presumably one row per tested time window), then reorder the columns into
# plotting order. Expected values are noted on the right.
# Flat prior (MLE):
bayes = flat_means.min(axis=0)[order_of_animals]          # [15.82963073, 16.06929415, 18.81734775, 17.03691594, 17.86089428]
bayes_med = flat_medians.min(axis=0)[order_of_animals]    # [12., 12.64911064, 14., 12.16552506, 14.14213562]
# With history ("Bayesian memory"):
memory_bayes = memory_means.min(axis=0)[order_of_animals]        # [15.46168191, 14.99576142, 18.26098815, 16.40828295, 16.5269506]
memory_bayes_med = memory_medians.min(axis=0)[order_of_animals]  # [11.3137085, 12., 13.11132817, 12.16552506, 12.64911064]

# 1D (linearised track) task. These arrays are not part of the pickle written
# below; panels (c) and (d) use hard-coded values instead.

# Bayesian decoder with a flat prior (MLE).
bay_track_dict_flat = loadmat("Bayes_res/Fig3cd_decodingDataForLinearizedT-mazeDiscard1st25s.mat")
t_flat_medians, t_flat_means = bay_track_dict_flat['medianErr'], bay_track_dict_flat['meanErr']
t_flat_win = bay_track_dict_flat['tWin2Test'].flatten()

# Bayesian decoder with history ("Bayesian memory").
# Old parameter set, kept for reference:
# bay_track_dict_history_old = loadmat("Bayes_res/1dDecodeFullBayesWithHistorySigma5History15.mat")
bay_track_dict_history = loadmat("Bayes_res/Fig3cd_1dDecodeFullBayesWithHistorySigma5History15ExcludeFirst25.mat")
t_memory_medians, t_memory_means = bay_track_dict_history['medianErr'], bay_track_dict_history['meanErr']
t_memory_win = bay_track_dict_history['tWin2Test'].flatten()


# Hand the processed results to the Python 3 plotting cells (a separate SoS
# kernel) through a pickle file. The element order is what those cells unpack.
results_for_plotting = [
    means, mean_std, memory_bayes, bayes, labels,
    flat_means, order_of_animals, flat_medians,
    memory_means, memory_medians, tableau20,
]
with open('train.pickle', 'wb') as f:
    pickle.dump(results_for_plotting, f)

(a) Mean errors in 2D decoding task

%use Python3

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
import numpy as np
# The Python 2 cell's `import pickle` lives in a different kernel; without this
# import the pickle.load below raises NameError.
import pickle

init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# Load the results saved by the Python 2 cell. encoding='bytes' is required for
# a Python 2 pickle and turns Python 2 `str` values (e.g. the labels) into
# `bytes`, so labels and the RNN statistics are redefined below as plain values.
with open('train.pickle', 'rb') as f:
    (means, mean_std, memory_bayes, bayes, labels, flat_means, order_of_animals,
     flat_medians, memory_means, memory_medians, tableau20) = pickle.load(f, encoding='bytes')

labels = ["R2192", "R2198", "R2336", "R2337", "R2217"]
order_of_animals = [0, 1, 3, 4, 2]
# Results from the RNN, extracted from the log files and averaged.
means = [12.50484, 13.27721, 16.28584545, 14.36655, 14.4936]
mean_std = [0.3846042705, 0.3226775496, 0.269230256, 0.2945978247, 0.2544971159]

medians = [10.3296, 10.75701, 13.09772, 11.3265, 11.73398]
median_std = [0.23188, 0.31606, 0.19988, 0.15216, 0.24249]

# Results from the flat-prior Bayes decoder: best error per animal, reordered so
# that R2217 (fewest neurons) comes last.
bayes = np.min(flat_means, axis=0)[order_of_animals]  # [15.82963073, 16.06929415, 18.81734775, 17.03691594, 17.86089428]
bayes_med = np.min(flat_medians, axis=0)[order_of_animals]  # [12., 12.64911064, 14., 12.16552506, 14.14213562]

# Results from the Bayes decoder with memory, same reduction and ordering.
memory_bayes = np.min(memory_means, axis=0)[order_of_animals]  # [15.46168191, 14.99576142, 18.26098815, 16.40828295, 16.5269506]
memory_bayes_med = np.min(memory_medians, axis=0)[order_of_animals]  # [11.3137085, 12., 13.11132817, 12.16552506, 12.64911064]

def _plotly_rgb(color):
    """Format an (r, g, b) tuple of floats in [0, 1] as a Plotly 'rgb(R, G, B)' string.

    `tableau20` holds matplotlib-style 0-1 channels (rescaled in the Python 2
    cell), but Plotly colour strings use CSS 0-255 channels. Passing the raw
    floats (e.g. 'rgb(0.12, 0.47, 0.71)') renders as near-black.
    """
    return "rgb({}, {}, {})".format(*(int(round(255 * c)) for c in color))


rnn_color = _plotly_rgb(tableau20[0])  # blue
mle_color = _plotly_rgb(tableau20[6])  # red
hover_mean = '<b> Label: </b> <i> %{x} </i>, <br> <b> Mean Error: </b> <i> %{y: .2f} cm </i>'

figa = go.Figure()
width_0 = 0.3  # RNN bar width; the two Bayesian bars are width_0 / 1.5 wide

# RNN: widest bar, leftmost in each animal's group, with mean_std error bars.
figa.add_trace(go.Bar(
    x=labels,
    y=means,
    # With an explicit `array`, Plotly draws absolute ('data') error bars.
    error_y=dict(type='data', array=mean_std, color='black', width=4, thickness=3),
    name='RNN',
    marker_color=rnn_color,
    width=width_0,
    offset=-0.4,
    text=means,
    textfont=dict(size=14, color=rnn_color),
    hovertemplate=hover_mean,
))

# Bayesian decoder with memory: middle bar.
figa.add_trace(go.Bar(
    x=labels,
    y=memory_bayes,
    name='Bayesian memory',
    marker_color='#a43032',
    width=width_0 / 1.5,
    offset=0.0,
    text=memory_bayes,
    textfont=dict(size=18, color="black"),
    hovertemplate=hover_mean,
))

# Flat-prior Bayesian decoder (MLE): rightmost bar.
figa.add_trace(go.Bar(
    x=labels,
    y=bayes,
    name='MLE',
    marker_color=mle_color,
    width=width_0 / 1.5,
    offset=0.205,
    text=bayes,
    textfont=dict(size=18, color=mle_color),
    hovertemplate=hover_mean,
))

figa.update_layout(title='(a)   Mean errors in 2D decoding task',
                   title_x=0.5,
                   xaxis=dict(range=[-0.45, 4.5],
                              mirror='all',
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   yaxis_title='Mean error (cm)',
                   yaxis=dict(range=[0, 24.5],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   legend=dict(orientation='h',
                               x=0.125,
                               y=1,
                               bordercolor="Gray",
                               borderwidth=1),
                   plot_bgcolor='#fff',
                   width=550,
                   height=410,
                   font=dict(size=13),
                   margin=go.layout.Margin(l=50, r=50, b=60, t=35),
                   bargap=0.25,
                   bargroupgap=0.1)

# Print each bar's value (2 decimals) above it.
figa.update_traces(texttemplate='%{text:.2f}', textposition='outside')

plot(figa, filename='fig3_a.html', config=config)
# THEBELAB
display(HTML('fig3_a.html'))
# BINDER
# iplot(figa,config=config)

(b) Median errors in 2D decoding task

%use Python3

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
# This cell uses pickle and np, but its import list contains neither (the
# Python 2 imports live in another kernel); import them so the cell runs on its own.
import pickle
import numpy as np

init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# Load the results saved by the Python 2 cell (encoding='bytes' for a Python 2
# pickle; Python 2 strings come back as bytes, so labels are redefined below).
with open('train.pickle', 'rb') as f:
    (means, mean_std, memory_bayes, bayes, labels, flat_means, order_of_animals,
     flat_medians, memory_means, memory_medians, tableau20) = pickle.load(f, encoding='bytes')

labels = ["R2192", "R2198", "R2336", "R2337", "R2217"]
# Results from the RNN.
means = [12.50484, 13.27721, 16.28584545, 14.36655, 14.4936]
mean_std = [0.3846042705, 0.3226775496, 0.269230256, 0.2945978247, 0.2544971159]
medians = [10.3296, 10.75701, 13.09772, 11.3265, 11.73398]
median_std = [0.23188, 0.31606, 0.19988, 0.15216, 0.24249]

# Results from the flat-prior Bayes decoder: best error per animal, reordered so
# that R2217 (fewest neurons) comes last.
bayes = np.min(flat_means, axis=0)[order_of_animals]  # [15.82963073, 16.06929415, 18.81734775, 17.03691594, 17.86089428]
bayes_med = np.min(flat_medians, axis=0)[order_of_animals]  # [12., 12.64911064, 14., 12.16552506, 14.14213562]

# Results from the Bayes decoder with memory, same reduction and ordering.
memory_bayes = np.min(memory_means, axis=0)[order_of_animals]  # [15.46168191, 14.99576142, 18.26098815, 16.40828295, 16.5269506]
memory_bayes_med = np.min(memory_medians, axis=0)[order_of_animals]  # [11.3137085, 12., 13.11132817, 12.16552506, 12.64911064]


def _plotly_rgb(color):
    """Format an (r, g, b) tuple of floats in [0, 1] as a Plotly 'rgb(R, G, B)' string.

    `tableau20` holds matplotlib-style 0-1 channels, but Plotly colour strings
    use CSS 0-255 channels; the raw floats would render as near-black.
    """
    return "rgb({}, {}, {})".format(*(int(round(255 * c)) for c in color))


rnn_color = _plotly_rgb(tableau20[0])  # blue
mle_color = _plotly_rgb(tableau20[6])  # red
# This panel shows MEDIAN errors, so the hover text and axis title say so.
hover_median = '<b> Label: </b> <i> %{x} </i>, <br> <b> Median Error: </b> <i> %{y: .2f} cm </i>'

figb = go.Figure()
width_0 = 0.3  # RNN bar width; the two Bayesian bars are width_0 / 1.5 wide

# RNN: widest bar, leftmost in each animal's group, with median_std error bars.
figb.add_trace(go.Bar(
    x=labels,
    y=medians,
    # With an explicit `array`, Plotly draws absolute ('data') error bars.
    error_y=dict(type='data', array=median_std, color='black', width=4, thickness=3),
    name='RNN',
    marker_color=rnn_color,
    width=width_0,
    offset=-0.4,
    text=medians,
    textfont=dict(size=14, color=rnn_color),
    hovertemplate=hover_median,
))

# Bayesian decoder with memory: middle bar.
figb.add_trace(go.Bar(
    x=labels,
    y=memory_bayes_med,
    name='Bayesian memory',
    marker_color='#a43032',
    width=width_0 / 1.5,
    offset=0.0,
    text=memory_bayes_med,
    textfont=dict(size=18, color="black"),
    hovertemplate=hover_median,
))

# Flat-prior Bayesian decoder (MLE): rightmost bar.
figb.add_trace(go.Bar(
    x=labels,
    y=bayes_med,
    name='MLE',
    marker_color=mle_color,
    width=width_0 / 1.5,
    offset=0.205,
    text=bayes_med,
    textfont=dict(size=18, color=mle_color),
    hovertemplate=hover_median,
))

figb.update_layout(title='(b)   Median errors in 2D decoding task',
                   title_x=0.5,
                   xaxis=dict(range=[-0.45, 4.5],
                              mirror='all',
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   yaxis_title='Median error (cm)',
                   yaxis=dict(range=[0, 24.5],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   legend=dict(orientation='h',
                               x=0.125,
                               y=1,
                               bordercolor="Gray",
                               borderwidth=1),
                   plot_bgcolor='#fff',
                   width=550,
                   height=410,
                   font=dict(size=13),
                   margin=go.layout.Margin(l=50, r=50, b=60, t=35),
                   bargap=0.25,
                   bargroupgap=0.1)

# Print each bar's value (2 decimals) above it.
figb.update_traces(texttemplate='%{text:.2f}', textposition='outside')

plot(figb, filename='fig3_b.html', config=config)
# THEBELAB
display(HTML('fig3_b.html'))
# BINDER
# iplot(figb,config=config)

(c) Mean errors in 1D decoding task

%use Python3

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
# The Python 2 cell's `import pickle` lives in a different kernel; without this
# import the pickle.load below raises NameError.
import pickle

init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# For this panel only `tableau20` is used from the pickle; the 1D results below
# are hard-coded and the other unpickled names are overwritten or unused here.
with open('train.pickle', 'rb') as f:
    (means, mean_std, memory_bayes, bayes, labels, flat_means, order_of_animals,
     flat_medians, memory_means, memory_medians, tableau20) = pickle.load(f, encoding='bytes')

# Results in 1D.
labels = ["R2192", "R2198", "R2336", "R2337", "R2217"]

# RNN results. Order of animals is changed compared to indexing in the .mat files.
median = [7.09, 7.86, 7.19, 7.39, 12.21]
means = [14.508, 22.273, 22.724, 21.087, 23.296]
mean_std = [3.7650, 1.292, 3.287, 0.617, 2.003]
median_std = [0.37, 0.36, 0.46, 0.49, 1.30]

# Flat-prior Bayes (MLE). Order of animals is changed compared to indexing in the .mat files.
bayes_med = [8, 8.00001, 8, 10, 24]  # 00001 is added to identify the value later in plotting
bayes_mean = [45.21788618, 52.29189189, 63.53548387, 69.39569892, 73.98571429]

# Bayes with memory. Order of animals is changed compared to indexing in the .mat files.
memory_bayes_med = [8.001, 10.001, 8.001, 10.001, 24.001]
memory_bayes_mean = [44.11041667, 65.38104265, 59.35448718, 66.03916084, 75.1372093]


def _plotly_rgb(color):
    """Format an (r, g, b) tuple of floats in [0, 1] as a Plotly 'rgb(R, G, B)' string.

    `tableau20` holds matplotlib-style 0-1 channels, but Plotly colour strings
    use CSS 0-255 channels; the raw floats would render as near-black.
    """
    return "rgb({}, {}, {})".format(*(int(round(255 * c)) for c in color))


rnn_color = _plotly_rgb(tableau20[0])  # blue
mle_color = _plotly_rgb(tableau20[6])  # red
hover_mean = '<b> Label: </b> <i> %{x} </i>, <br> <b> Mean Error: </b> <i> %{y: .2f} cm </i>'

figc = go.Figure()
width_0 = 0.3  # RNN bar width; the two Bayesian bars are width_0 / 1.5 wide

# RNN: widest bar, leftmost in each animal's group, with mean_std error bars.
figc.add_trace(go.Bar(
    x=labels,
    y=means,
    # With an explicit `array`, Plotly draws absolute ('data') error bars.
    error_y=dict(type='data', array=mean_std, color='black', width=4, thickness=3),
    name='RNN',
    marker_color=rnn_color,
    width=width_0,
    offset=-0.4,
    text=means,
    textfont=dict(size=14, color=rnn_color),
    hovertemplate=hover_mean,
))

# Bayesian decoder with memory: middle bar.
figc.add_trace(go.Bar(
    x=labels,
    y=memory_bayes_mean,
    name='Bayesian memory',
    marker_color="#a43032",
    width=width_0 / 1.5,
    offset=0.0,
    text=memory_bayes_mean,
    textfont=dict(size=18, color="black"),
    hovertemplate=hover_mean,
))

# Flat-prior Bayesian decoder (MLE): rightmost bar.
figc.add_trace(go.Bar(
    x=labels,
    y=bayes_mean,
    name='MLE',
    marker_color=mle_color,
    width=width_0 / 1.5,
    offset=0.205,
    text=bayes_mean,
    textfont=dict(size=18, color=mle_color),
    hovertemplate=hover_mean,
))

figc.update_layout(title='(c)   Mean errors in 1D decoding task',
                   title_x=0.5,
                   xaxis=dict(range=[-0.45, 4.5],
                              mirror='all',
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   yaxis_title='Mean error (cm)',
                   yaxis=dict(range=[0, 90],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   legend=dict(orientation='h',
                               x=0.125,
                               y=1,
                               bordercolor="Gray",
                               borderwidth=1),
                   plot_bgcolor='#fff',
                   width=550,
                   height=410,
                   font=dict(size=13),
                   margin=go.layout.Margin(l=50, r=50, b=60, t=35),
                   bargap=0.25,
                   bargroupgap=0.1)

# Print each bar's value (2 decimals) above it.
figc.update_traces(texttemplate='%{text:.2f}', textposition='outside')

plot(figc, filename='fig3_c.html', config=config)
# THEBELAB
display(HTML('fig3_c.html'))
# BINDER
# iplot(figc,config=config)

(d) Median errors in 1D decoding task

%use Python3

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
# The Python 2 cell's `import pickle` lives in a different kernel.
import pickle

init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# Load the colour palette (last item of the list pickled by the Python 2 cell)
# so this cell does not depend on `tableau20` left over from earlier cells.
with open('train.pickle', 'rb') as f:
    tableau20 = pickle.load(f, encoding='bytes')[-1]

labels = ["R2192", "R2198", "R2336", "R2337", "R2217"]

# Results from the RNN (1D task).
median = [7.09, 7.86, 7.19, 7.39, 12.21]
means = [14.508, 22.273, 22.724, 21.087, 23.296]
# NOTE(review): panel (c) uses mean_std 3.287 for R2336 and 0.617 for R2337,
# where this list has 6.9994 and 0.62. mean_std is not plotted in this panel;
# confirm which set is correct.
mean_std = [3.7650, 1.292, 6.9994, 0.62, 2.003]
median_std = [0.37, 0.36, 0.46, 0.49, 1.30]

# Flat-prior Bayes (MLE). Order of animals is changed compared to indexing in the .mat files.
bayes_med = [8, 8.00001, 8, 10, 24]  # 00001 is added to identify the value later in plotting
bayes_mean = [45.21788618, 52.29189189, 63.53548387, 69.39569892, 73.98571429]

# Bayes with memory. Order of animals is changed compared to indexing in the .mat files.
memory_bayes_med = [8.001, 10.001, 8.001, 10.001, 24.001]
memory_bayes_mean = [44.11041667, 65.38104265, 59.35448718, 66.03916084, 75.1372093]

def _plotly_rgb(color):
    """Format an (r, g, b) tuple of floats in [0, 1] as a Plotly 'rgb(R, G, B)' string.

    `tableau20` holds matplotlib-style 0-1 channels, but Plotly colour strings
    use CSS 0-255 channels; the raw floats would render as near-black.
    """
    return "rgb({}, {}, {})".format(*(int(round(255 * c)) for c in color))


# `tableau20` must already be defined in this kernel (it is loaded from
# train.pickle in the Python 3 cells above).
rnn_color = _plotly_rgb(tableau20[0])  # blue
mle_color = _plotly_rgb(tableau20[6])  # red
# This panel shows MEDIAN errors, so the hover text and axis title say so.
hover_median = '<b> Label: </b> <i> %{x} </i>, <br> <b> Median Error: </b> <i> %{y: .2f} cm </i>'

figd = go.Figure()
width_0 = 0.3  # RNN bar width; the two Bayesian bars are width_0 / 1.5 wide

# RNN: widest bar, leftmost in each animal's group, with median_std error bars.
figd.add_trace(go.Bar(
    x=labels,
    y=median,
    # With an explicit `array`, Plotly draws absolute ('data') error bars.
    error_y=dict(type='data', array=median_std, color='black', width=4, thickness=3),
    name='RNN',
    marker_color=rnn_color,
    width=width_0,
    offset=-0.4,
    text=median,
    textfont=dict(size=14, color=rnn_color),
    hovertemplate=hover_median,
))

# Bayesian decoder with memory: middle bar.
figd.add_trace(go.Bar(
    x=labels,
    y=memory_bayes_med,
    name='Bayesian memory',
    marker_color="#a43032",
    width=width_0 / 1.5,
    offset=0.0,
    text=memory_bayes_med,
    textfont=dict(size=18, color="black"),
    hovertemplate=hover_median,
))

# Flat-prior Bayesian decoder (MLE): rightmost bar.
figd.add_trace(go.Bar(
    x=labels,
    y=bayes_med,
    name='MLE',
    marker_color=mle_color,
    width=width_0 / 1.5,
    offset=0.205,
    text=bayes_med,
    textfont=dict(size=18, color=mle_color),
    hovertemplate=hover_median,
))

figd.update_layout(title='(d)   Median errors in 1D decoding task',
                   title_x=0.5,
                   xaxis=dict(range=[-0.45, 4.5],
                              mirror='all',
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   yaxis_title='Median error (cm)',
                   yaxis=dict(range=[0, 30.5],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   legend=dict(orientation='h',
                               x=0.125,
                               y=1,
                               bordercolor="Gray",
                               borderwidth=1),
                   plot_bgcolor='#fff',
                   width=550,
                   height=410,
                   font=dict(size=13),
                   margin=go.layout.Margin(l=50, r=50, b=60, t=35),
                   bargap=0.25,
                   bargroupgap=0.1)

# Print each bar's value (2 decimals) above it.
figd.update_traces(texttemplate='%{text:.2f}', textposition='outside')

plot(figd, filename='fig3_d.html', config=config)
# THEBELAB
display(HTML('fig3_d.html'))
# BINDER
# iplot(figd,config=config)