Search
Fig 4. Analysis of the errors in function of location and neural activity for rat R2192

This code example is a Jupyter notebook with a Script of Scripts (SoS) workflow.

The calculations are written in Python 2.7 (taken from the authors' repository), and the interactive figures are written in Python 3.6 with Plotly.

Figure 4:

(a) The trajectory of the rat during the entire trial. Not all parts of the arena are visited with equal frequency. (b) The average size of errors made in different regions of space. The color of each hexagon depicts the average Euclidean error of the data points falling into that hexagon. More frequently visited areas (as seen in (a)) tend to have lower mean error. (c) Summed neural activity in different regions of space. For each data point we sum the spike counts of all 63 neurons in a 1400 ms period centered on the moment the location was recorded. The color of each hexagon corresponds to the average over all data points falling into the hexagon. Areas where summed neural activity is high have lower prediction error. (d) The prediction error of a coordinate decreases as the animal gets closer to the wall perpendicular to that coordinate.

https://doi.org/10.1371/journal.pcbi.1006822.g004

%use Python2

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import binned_statistic
from sklearn.cross_validation import KFold
import math
import scipy.io as scio
import matplotlib.style
import matplotlib as mpl
mpl.style.use('classic')
%matplotlib inline
import pickle 


# These are the "Tableau 20" colors as RGB (0-255 channels).
# (The prediction file is loaded further down, right before it is used; an
# earlier identical load here was dead code because `data`/`y` were
# overwritten before any use.)
tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]

# Scale the RGB values to the [0, 1] range, which is the format matplotlib accepts.
# Float division (255.) keeps this correct under Python 2 as well.
tableau20 = [(r / 255., g / 255., b / 255.) for r, g, b in tableau20]
    
def mse(y, t, axis=2):
    """Mean squared error between arrays `y` and `t`.

    The squared differences are first averaged along `axis`, and the
    resulting array is then averaged over all remaining elements.
    """
    per_item = ((y - t) ** 2).mean(axis=axis)
    return per_item.mean()

def mean_distance(y, t, axis=2):
    """Average Euclidean distance between `y` and `t`.

    Coordinates are taken along `axis`; the mean is over every remaining element.
    """
    diff = y - t
    lengths = np.sqrt((diff ** 2).sum(axis=axis))
    return lengths.mean()

def median_distance(y, t, axis=2):
    """Median Euclidean distance between `y` and `t`.

    Coordinates are taken along `axis`; the median is over every remaining element.
    """
    diff = y - t
    lengths = np.sqrt((diff ** 2).sum(axis=axis))
    return np.median(lengths)

cvfolds = 10
# Number of independently trained model runs whose prediction files
# (..._predictions_v1.npz ... _v<n_runs>.npz) are combined into the ensemble below.
n_runs = 10
data = np.load("R2192_models/window_scan_R2192_1x1400_predictions_v1.npz")

pred_y = data["preds"]
y = data["targets"]
targets = data["targets"]

seqlen = pred_y.shape[1]  # NOTE(review): not used by any visible cell
pred_y = pred_y.reshape((-1, 2))

# Euclidean error of run v1 alone, one value per row.
errors = np.sqrt(np.sum((pred_y - y)**2, axis=1))

# WHAT IS CHANCE LEVEL ACCURACY?
# NOTE(review): r_y (shuffled targets) is prepared for a chance-level
# baseline but is never used afterwards.
r_y = y.copy()
np.random.shuffle(r_y)


all_preds = []
all_errors = []

# Assumes there are predictions from n_runs CV-runs in the folder R2192_models.
for i in range(1, n_runs + 1):
    data = np.load("R2192_models/window_scan_R2192_1x1400_predictions_v%d.npz" % i)
    pr_i = data["preds"]
    tgts = data["targets"]
    # All runs must be evaluated on exactly the same targets, otherwise
    # averaging their predictions/errors is meaningless. np.array_equal also
    # rejects shape mismatches, and an explicit raise is not stripped by -O.
    if not np.array_equal(tgts, targets):
        raise ValueError("targets in run v%d differ from those in run v1" % i)

    all_preds.append(pr_i)
    errors_i = np.sqrt(np.sum((pr_i - y)**2, axis=1))
    # Errors of each model individually, not the error of the ensemble model.
    all_errors.append(errors_i)

avg_pred = np.mean(all_preds, axis=0)  # prediction of the ensemble model (mean over runs)
avg_errors = np.mean(all_errors, axis=0)  # mean of the per-run errors per row, NOT the ensemble's error
median_errors = np.median(all_errors, axis=0)


def min_dist_in_train(node, nodes):
    """Return the smallest *squared* Euclidean distance from `node` to any row of `nodes`.

    `nodes` may be any array-like of shape (n, d); `node` must broadcast against its rows.
    """
    offsets = np.asarray(nodes) - node
    sq_lengths = np.einsum('ij,ij->i', offsets, offsets)
    return sq_lengths.min()

def nearby_in_train(node, nodes, max_sq_dist=10):
    """Count the rows of `nodes` that lie close to `node`.

    A row counts as "nearby" when its *squared* Euclidean distance to `node`
    is strictly below `max_sq_dist`. The default of 10 therefore corresponds
    to a radius of sqrt(10) ~= 3.16 in position units.

    The previous implementation sorted all distances and scanned them with a
    while loop. That raised IndexError whenever every row was nearby (or
    `nodes` was empty). A vectorised count gives the same result for all other
    inputs, without sorting.

    Args:
        node: a single point, broadcastable against the rows of `nodes`.
        nodes: array-like of shape (n, d).
        max_sq_dist: strict upper bound on the squared distance.

    Returns:
        int: number of nearby rows.
    """
    nodes = np.asarray(nodes)
    deltas = nodes - node
    dist_2 = np.einsum('ij,ij->i', deltas, deltas)
    return int(np.count_nonzero(dist_2 < max_sq_dist))

# Load the raw position and feature matrices; their row index is the sample
# index that KFold below splits on.
# NOTE(review): raw_feat presumably has one column per neuron (spike counts
# per window, per the figure caption) — TODO confirm.
raw_y = np.loadtxt("data/R2192_1x1400_at35_step200_bin100-RAW_pos.dat")
raw_feat = np.loadtxt("data/R2192_1x1400_at35_step200_bin100-RAW_feat.dat")

# Per-test-sample results, accumulated fold by fold in test-index order.
# closest_points: min *squared* distance from each test position to any training position.
# count_nearby: number of training positions with squared distance < 10 to each test position.
# activity: row sum of raw_feat for each test sample (the only one of the three
#           written to train.pickle further down).
closest_points=[]
count_nearby = []
activity=[]

# Re-create 10 folds with the pre-0.20 sklearn API KFold(n, n_folds); it does
# not shuffle by default, so each test fold is one contiguous block of indices.
# The trimming below drops 99 (or 100) indices around each test block;
# presumably this keeps windows that overlap the other set out of both sets —
# TODO confirm against the training script.
for i, (rest_idx, test_idx) in enumerate(KFold(raw_y.shape[0], 10)):
    if np.min(test_idx)==0: #first fold left out
        # Training data lies only after the test block: drop the 99 training
        # indices right after it, and the first 99 test indices.
        rest_idx = np.array(sorted(rest_idx)[99:])
        test_idx = np.array(sorted(test_idx)[99:])
    elif np.max(test_idx)== raw_y.shape[0]-1: #last fold left out
        # Training data lies only before the test block: drop the first 99
        # indices of the recording, and the first 99 test indices (the ones
        # right after the training block).
        rest_idx = np.array(sorted(rest_idx)[99:])
        test_idx = np.array(sorted(test_idx)[99:])
    else:
        # Middle fold: training data exists on both sides of the test block.
        assert (np.min(test_idx)>np.min(rest_idx))
        assert (np.max(test_idx)<np.max(rest_idx))
        
        all_idx = np.array(range(raw_y.shape[0]))
        #print all_idx[99:np.min(test_idx)].shape,all_idx[np.max(test_idx)+100:].shape
        # Train = [99, start of test) + [end of test + 100, end);
        # test = the test block without its first 99 indices.
        rest_idx = np.hstack((all_idx[99:np.min(test_idx)],all_idx[np.max(test_idx)+100:]))
        test_idx = np.array(sorted(test_idx)[99:])
#     print "Fold nr ",i," shapes ", test_idx.shape, rest_idx.shape     
    test = raw_y[test_idx]
    rest = raw_y[rest_idx]
    # Brute force: every test position is compared against every training
    # position of this fold (two full distance computations per test position).
    for loc in test:
        closest_points.append(min_dist_in_train(loc,rest))
        count_nearby.append(nearby_in_train(loc,rest))
    # find the lines in activity matrix, sum and append
    # NOTE(review): the concatenation order presumably matches the order of
    # the prediction targets loaded above — TODO confirm.
    test_activity = np.sum(raw_feat[test_idx,:],axis=1)
    activity = np.concatenate([activity,test_activity])

# Absolute error of the ensemble prediction, separately for each coordinate.
xerrors, yerrors = (np.abs(avg_pred[:, k] - y[:, k]) for k in (0, 1))
bins = np.arange(0, 100.1, 2)  # bin edges 0, 2, 4, ..., 100

# Pool both coordinates. Each coordinate's error is paired with the distance
# to the nearer of the two walls perpendicular to that axis; 50 - |v - 50|
# equals min(v, 100 - v), assuming coordinates lie in [0, 100].
all_1D_err = np.concatenate((xerrors, yerrors))
all_distances_to_wall = np.concatenate(tuple(50 - np.abs(y[:, k] - 50) for k in (0, 1)))
statistic, bin_edges, binnumber = binned_statistic(
    all_distances_to_wall, all_1D_err, statistic='mean', bins=bins)

# Hand the results over to the Python 3 plotting cells, which read train.pickle.
payload = [tableau20, all_preds, all_errors, activity, statistic, bin_edges, binnumber]
with open('train.pickle', 'wb') as handle:
    pickle.dump(payload, handle)

(a) Trajectory of R2192

%use Python3

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
# display/HTML are public in IPython.display; importing them from
# IPython.core.display is deprecated in recent IPython releases.
from IPython.display import display, HTML
init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}
import numpy as np

# Recorded positions: column 0 is plotted as x and column 1 as y (cm).
data = np.load("R2192_models/window_scan_R2192_1x1400_predictions_v1.npz")
y = data["targets"]

figa = go.Figure()

# Trajectory drawn as one continuous line, in array (sample) order.
figa.add_trace(go.Scatter(x=y[:, 0],
                          y=y[:, 1],
                          mode='lines',
                          line=dict(color='blue',
                                    width=1),
                          name="Trajectory of <br> R2192",
                          hovertemplate='<b>x: </b> <i> %{x} cm </i> <br> <b>y: </b> <i> %{y} cm </i>'))


# Axis ranges are clamped to the extent of the recorded positions.
figa.update_layout(title='(a)',
                   title_x=0.5,
                   xaxis_title='X-coordinate (cm)',
                   xaxis=dict(range=[np.min(y[:, 0]), np.max(y[:, 0])],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   yaxis_title='Y-coordinate (cm)',
                   yaxis=dict(range=[np.min(y[:, 1]), np.max(y[:, 1])],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   plot_bgcolor='#fff',
                   width=520,
                   height=500,
                   margin=go.layout.Margin(l=50,
                                           r=50,
                                           b=60,
                                           t=35),
                   font=dict(size=14))


# Write a standalone HTML file and embed it inline. auto_open=False stops
# plotly from also trying to open the file in a web browser from the kernel.
plot(figa, filename='fig4_a.html', config=config, auto_open=False)
# THEBELAB
display(HTML('fig4_a.html'))
# BINDER
# iplot(figa,config=config)

(b) Mean prediction error for different regions for R2192

%use Python3

# pickle was used below but never imported in the Python 3 kernel (NameError).
import pickle

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from hexplot import get_hexbin_attributes, pl_cell_color, make_hexagon, mpl_to_plotly
# display/HTML are public in IPython.display; IPython.core.display is deprecated.
from IPython.display import display, HTML
init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# train.pickle is written by the Python 2 cell. encoding='bytes' is needed to
# read Python 2 pickles (their 8-bit strings) under Python 3.
with open('train.pickle', 'rb') as f:
    tableau20, all_preds, all_errors, activity, statistic, bin_edges, binnumber = pickle.load(f, encoding='bytes')

# Pre-computed hexbin for this panel.
# NOTE(review): HB[0] presumably is a matplotlib hexbin collection (it has
# .cmap) — see hexplot for what get_hexbin_attributes expects.
with open('HB_1.pickle', 'rb') as f:
    HB = pickle.load(f, encoding='bytes')

with open('centers_shapes.pickle1', 'rb') as f:
    centers, shapes = pickle.load(f, encoding='bytes')

hexagon_vertices, offsets, mpl_facecolors, counts = get_hexbin_attributes(HB[0])
pl_algae = mpl_to_plotly(HB[0].cmap, 11)
X, Y = zip(*centers)

# Text displayed when hovering over a cell.
# NOTE(review): assumes counts[k] belongs to centers[k]; the label says
# "Counts" although the caption describes the colour as the mean error.
text = [f'x: {round(X[k],2)}<br>y: {round(Y[k],2)}<br>Counts: {int(counts[k])}' for k in range(len(X))]

# Near-invisible markers at the hexagon centres carry the hover text and the
# colour bar; the hexagons themselves are drawn as layout shapes.
trace = go.Scatter(
    x=list(X),
    y=list(Y),
    mode='markers',
    marker=dict(size=0.5,
                color=counts,
                colorscale=pl_algae,
                showscale=True,
                colorbar=dict(thickness=20,
                              ticklen=4)),
    text=text,
    hoverinfo='text'
)

axis = dict(showgrid=False,
            showline=True,
            zeroline=False,
            ticklen=4)
layout = go.Layout(title='Hexbin plot',
                   width=530, height=550,
                   xaxis=axis,
                   yaxis=axis,
                   hovermode='closest',
                   shapes=shapes,
                   plot_bgcolor='black')

figb = go.FigureWidget(data=[trace], layout=layout)

figb.update_layout(title='(b)',
                   title_x=0.5,
                   xaxis_title='X-coordinate (cm)',
                   xaxis=dict(mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickvals=[0, 20, 40, 60, 80, 100],
                              tickfont=dict(size=16)),
                   yaxis_title='Y-coordinate (cm)',
                   yaxis=dict(mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   plot_bgcolor='#fff',
                   width=550,
                   height=495,
                   shapes=shapes,
                   hovermode='closest',
                   margin=go.layout.Margin(l=50,
                                           r=50,
                                           b=60,
                                           t=35),
                   font=dict(size=14))

# Write a standalone HTML file and embed it inline; auto_open=False avoids
# launching a browser from the kernel.
plot(figb, filename='fig4_b.html', config=config, auto_open=False)
# THEBELAB
display(HTML('fig4_b.html'))
# BINDER
# iplot(figb,config=config)

(c) Summed neural activity (spike counts) in different regions for R2192

%use Python3

# pickle was used below but never imported in the Python 3 kernel (NameError).
import pickle

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from hexplot import get_hexbin_attributes, pl_cell_color, make_hexagon, mpl_to_plotly
# display/HTML are public in IPython.display; IPython.core.display is deprecated.
from IPython.display import display, HTML
init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# train.pickle is written by the Python 2 cell. encoding='bytes' is needed to
# read Python 2 pickles (their 8-bit strings) under Python 3.
with open('train.pickle', 'rb') as f:
    tableau20, all_preds, all_errors, activity, statistic, bin_edges, binnumber = pickle.load(f, encoding='bytes')

# Pre-computed hexbin for this panel (summed activity).
# NOTE(review): HB[0] presumably is a matplotlib hexbin collection (it has
# .cmap) — see hexplot for what get_hexbin_attributes expects.
with open('HB_2.pickle', 'rb') as f:
    HB = pickle.load(f, encoding='bytes')

with open('centers_shapes.pickle2', 'rb') as f:
    centers, shapes = pickle.load(f, encoding='bytes')

_, _, _, counts = get_hexbin_attributes(HB[0])
pl_algae = mpl_to_plotly(HB[0].cmap, 11)
X, Y = zip(*centers)

# Text displayed when hovering over a cell (assumes counts[k] belongs to centers[k]).
text = [f'x: {round(X[k],2)}<br>y: {round(Y[k],2)}<br>Counts: {int(counts[k])}' for k in range(len(X))]

# Near-invisible markers at the hexagon centres carry the hover text and the
# colour bar; the hexagons themselves are drawn as layout shapes.
trace = go.Scatter(
    x=list(X),
    y=list(Y),
    mode='markers',
    marker=dict(size=0.5,
                color=counts,
                colorscale=pl_algae,
                showscale=True,
                colorbar=dict(thickness=20,
                              ticklen=4)),
    text=text,
    hoverinfo='text'
)

axis = dict(showgrid=False,
            showline=True,
            zeroline=False,
            ticklen=4)
# Build this panel's own base layout. Previously `layout` was left over from
# panel (b): that is a NameError if (b) did not run, and otherwise panel (b)'s
# settings and hexagon shapes are used as the base that update_layout merges into.
layout = go.Layout(title='Hexbin plot',
                   width=530, height=550,
                   xaxis=axis,
                   yaxis=axis,
                   hovermode='closest',
                   shapes=shapes,
                   plot_bgcolor='black')

figc = go.FigureWidget(data=[trace], layout=layout)

figc.update_layout(title='(c)',
                   title_x=0.5,
                   xaxis_title='X-coordinate (cm)',
                   xaxis=dict(mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickvals=[0, 20, 40, 60, 80, 100],
                              tickfont=dict(size=16)),
                   yaxis_title='Y-coordinate (cm)',
                   yaxis=dict(mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickfont=dict(size=16)),
                   plot_bgcolor='#fff',
                   width=550,
                   height=495,
                   shapes=shapes,
                   hovermode='closest',
                   margin=go.layout.Margin(l=50,
                                           r=50,
                                           b=60,
                                           t=35),
                   font=dict(size=14))

# Write a standalone HTML file and embed it inline; auto_open=False avoids
# launching a browser from the kernel.
plot(figc, filename='fig4_c.html', config=config, auto_open=False)
# THEBELAB
display(HTML('fig4_c.html'))
# BINDER
# iplot(figc,config=config)

(d) Prediction error of a coordinate

%use Python3

# This cell previously relied on np, tableau20, statistic and bin_edges
# leaking in from other cells; import/load them here so it runs on its own.
import pickle

import numpy as np
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
# display/HTML are public in IPython.display; IPython.core.display is deprecated.
from IPython.display import display, HTML
init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# train.pickle is written by the Python 2 cell. encoding='bytes' is needed to
# read Python 2 pickles under Python 3.
with open('train.pickle', 'rb') as f:
    tableau20, all_preds, all_errors, activity, statistic, bin_edges, binnumber = pickle.load(f, encoding='bytes')

# tableau20 stores RGB triples scaled to [0, 1] (matplotlib's format), but a
# CSS/Plotly rgb() string expects 0-255 channels. Formatting the raw floats
# produced e.g. "rgb(0.12..., 0.46..., 0.70...)", i.e. almost black.
bar_color = 'rgb({}, {}, {})'.format(*(int(round(255 * c)) for c in tableau20[0]))

figd = go.Figure()

# NOTE(review): with 2 cm bins, bin_edges[1:] - 1.5 places each bar 0.5 cm
# into its bin rather than at the bin centre (bin_edges[:-1] + 1). This is kept
# because the x-axis range below is tuned to these positions.
figd.add_trace(go.Bar(
    x=bin_edges[1:] - 1.5,
    y=statistic,
    width=0.85,
    name='Mean prediction <br> error R2192',
    marker_color=bar_color,
    hovertemplate='<b>Distance (x): </b> <i> %{x} cm</i> <br> <b>Prediction error (y):</b> <i> %{y} </i>'
))

figd.update_layout(title='(d)',
                   title_x=0.5,
                   xaxis_title='Distance to the nearest perpendicular wall (cm)',
                   xaxis=dict(range=[-0.5, 25.4],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              linecolor='#000',
                              tickvals=np.arange(0, 24.1, 2),
                              tickfont=dict(size=15)),
                   yaxis_title='Prediction error of the coordinate',
                   yaxis=dict(range=[0, 12.3],
                              mirror=True,
                              ticks='outside',
                              showline=True,
                              tickvals=np.arange(0, 12.1, 2),
                              linecolor='#000',
                              tickfont=dict(size=15)),
                   plot_bgcolor='#fff',
                   width=550,
                   height=445,
                   font=dict(size=14),
                   margin=go.layout.Margin(l=50,
                                           r=50,
                                           b=60,
                                           t=35),
                   bargap=0.25,
                   bargroupgap=0.1)

# Write a standalone HTML file and embed it inline; auto_open=False avoids
# launching a browser from the kernel.
plot(figd, filename='fig4_d.html', config=config, auto_open=False)
# THEBELAB
display(HTML('fig4_d.html'))
# BINDER
# iplot(figd,config=config)