Search
Supplemental Figure - Supporting Information

Supplemental Figure: What if $n_a \neq n_b$ ?

Reviewers made a good suggestion to show the case when $n_a$ and $n_b$ are very different (one sample much larger than the other). What happens in this case?

They also asked what happens when the two total repertoire (population) sizes are unequal, i.e. when they are not both equal to 60.

We therefore consider the following conditions:

  • 1 - $n_a = n_b = 40$
  • 2 - $n_a=10, n_b=40$
  • A - total repertoire sizes are each 60
  • B - one repertoire of size 60 and the other of size 120

We plot all four combinations: 1A, 1B, 2A, and 2B.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from scipy.stats import hypergeom
from scipy.stats import binned_statistic as binsta
from scipy.special import logsumexp
from util import *
import palettable as pal
# Base qualitative palette (cartocolors Prism_10), taken from its mpl_colors attribute.
clrx = pal.cartocolors.qualitative.Prism_10.mpl_colors
# Keep only the entries at positions 1, 2, 4, 5 and 6 of the base palette.
clr = tuple(x for n,x in enumerate(clrx) if n in [1,2,4,5,6])
# Sequential palette (cartocolors agSunset_7), also via mpl_colors.
clr2 = pal.cartocolors.sequential.agSunset_7.mpl_colors
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
# Set up plotly's offline notebook rendering.
# NOTE(review): connected=True presumably loads plotly.js from the CDN instead of
# embedding it in the notebook -- confirm against the plotly documentation.
init_notebook_mode(connected = True)
# Options handed to plot(..., config=config) when the figure is written out; judging by
# the key names they hide the "edit chart" link and the mode bar -- TODO confirm.
config={'showLink': False, 'displayModeBar': False}

# CCP, the Coupon Collector's Problem
def ccp_sample(c,pool=60):
    """Return how many distinct items appear in c uniform draws with replacement.

    Items are drawn from range(pool) with np.random.choice, so the result is
    random and lies between min(c, 1) and min(c, pool).
    """
    draws = np.random.choice(pool, c)
    return np.unique(draws).size

# Draw overlap
def nab_sample(s,na,nb,pool=60):
    """Draw a random empirical overlap using two chained hypergeometric draws.

    First, na items are sampled from a pool of `pool` items of which s are
    "shared"; the number of shared items obtained is sa. Second, nb items are
    sampled from a pool of `pool` items of which sa are marked; the number of
    marked items obtained is returned as the overlap.
    """
    draw = np.random.hypergeometric
    shared_in_a = draw(s, pool - s, na)
    return draw(shared_in_a, pool - shared_in_a, nb)

# Overlap between two PCRs of depth c and overlap s
def pcr_sample(c,s):
    """Simulate two PCRs of depth c whose repertoires truly share s items.

    Returns the tuple (nab, na, nb): na and nb are the numbers of distinct
    items recovered by each PCR (via ccp_sample with its default pool) and nab
    is the sampled overlap between them (via nab_sample with its default pool).
    """
    # Draw order matters for reproducibility under a fixed seed: a, then b, then overlap.
    distinct_a = ccp_sample(c)
    distinct_b = ccp_sample(c)
    overlap = nab_sample(s, distinct_a, distinct_b)
    return overlap, distinct_a, distinct_b

# Draw na and nb samples from two populations of size pool_a and pool_b, with true overlap s
# and return empirical overlap between na and nb
# note that this is basically the same as nab_sample, but with two different size pools!
def nab_sample_unequal(s,na,nb,pool_a,pool_b):
    """Draw a random empirical overlap when the two pools have different sizes.

    Same two-step hypergeometric scheme as nab_sample, except that the first
    draw (na items, s of them shared) uses a pool of size pool_a and the second
    draw (nb items, sa of them marked) uses a pool of size pool_b.
    """
    draw = np.random.hypergeometric
    shared_in_a = draw(s, pool_a - s, na)
    return draw(shared_in_a, pool_b - shared_in_a, nb)


def p_ccp(c, pool=60):
    """Return the coupon-collector distribution after c draws.

    Entry k of the returned length-(pool+1) vector is the probability that c
    uniform draws with replacement from `pool` items contain exactly k distinct
    items.
    """
    ks = np.arange(1, pool+1)
    dist = np.zeros(pool+1)
    dist[0] = 1
    for _ in range(c):
        # One more draw: either repeat one of the k items already seen (prob k/pool),
        # or move up from k-1 by hitting a new item (prob 1-(k-1)/pool).
        updated = np.zeros(pool+1)
        updated[1:] = dist[1:]*ks/pool + dist[:-1]*(1-(ks-1)/pool)
        dist = updated
    return dist

def p_overlap(na,nb,nab,pool=60):
    """Return the normalised distribution over the true overlap s = 0..pool.

    For each candidate s, the weight is the probability of observing nab under
    the two-step hypergeometric model, marginalised over the intermediate count
    sa. The weights are then divided by their sum (i.e. a flat prior over s).
    """
    # reference: hypergeom.pmf(outcome, Total, hits, Draws, loc=0)
    support = np.arange(pool+1)
    # P(nab | sa) for every sa; it does not depend on s, so it is computed once.
    lik_nab_given_sa = hypergeom.pmf(nab, pool, support, nb)
    weights = np.array([
        # P(sa | s) over all sa, dotted with P(nab | sa)
        np.dot(hypergeom.pmf(support, pool, s, na), lik_nab_given_sa)
        for s in support
    ])
    return weights/np.sum(weights)

def e_overlap(na,nb,nab,pool=60):
    """Return the mean of the distribution over s produced by p_overlap."""
    posterior = p_overlap(na, nb, nab, pool=pool)
    s_values = np.arange(pool+1)
    return np.dot(s_values, posterior)


def credible_interval(na,nb,nab,pct=90,pool=60):
    """Return (lower, expectation, upper) for the overlap distribution from p_overlap.

    pct may be given as a percentage (e.g. 90) or as a fraction (e.g. 0.9);
    any value greater than 1 is treated as a percentage. Each tail is allotted
    (1 - pct) / 2 of the probability mass.
    """
    posterior = p_overlap(na, nb, nab, pool=pool)
    # adjust for fractions vs percents; put everything as a fraction
    level = pct/100 if pct > 1 else pct
    tail = (1-level)/2

    mass_at_or_below = np.cumsum(posterior)
    mass_at_or_above = np.cumsum(posterior[::-1])[::-1]

    # lower bound: the first index whose cumulative mass reaches the tail cutoff
    # (falls back to 0 if no index qualifies)
    low_hits = np.flatnonzero(mass_at_or_below >= tail)
    lower = low_hits[0] if low_hits.size else 0

    # upper bound: the last index whose upper-tail mass still reaches the cutoff
    # (falls back to pool if no index qualifies)
    high_hits = np.flatnonzero(mass_at_or_above >= tail)
    upper = high_hits[-1] if high_hits.size else pool

    expectation = np.dot(np.arange(pool+1), posterior)
    return lower,expectation,upper


def p_nab_given_c(s,c,pool=60):
    """Return the joint mass p_gen[na, nb, nab] for two PCRs of depth c and true overlap s.

    na and nb (distinct items recovered by each PCR) follow the coupon-collector
    distribution from p_ccp; given them, nab follows the two-step hypergeometric
    model marginalised over the intermediate count sa. The result has shape
    (pool+1, pool+1, pool+1).

    Fixes relative to the earlier version:
      * nab now runs up to and including min(na, nb), which is a reachable
        outcome (e.g. s == pool with na == nb); it was previously dropped, so
        the table could sum to less than 1.
      * `pool` is forwarded to p_ccp, so the function is usable for pool != 60.
    """
    # Both PCRs have the same depth, so a single coupon-collector vector serves both.
    p_n = p_ccp(c,pool=pool)
    sa_values = np.arange(pool+1)
    p_gen = np.zeros([pool+1,pool+1,pool+1])
    for na in range(1,pool+1):
        # P(sa | s, na) over all sa
        p_sa = hypergeom.pmf(sa_values,pool,s,na)
        for nb in range(1,pool+1):
            pna_pnb = p_n[na] * p_n[nb]
            # include the endpoint: the overlap can equal the smaller sample size
            for nab in range(0,min(na,nb)+1):
                # P(nab | sa, nb) over all sa
                p_nab_given_sa = hypergeom.pmf(nab,pool,sa_values,nb)
                p_nab_given_s = np.dot(p_sa,p_nab_given_sa)
                p_gen[na,nb,nab] = p_nab_given_s * pna_pnb
    return p_gen

def p_shat_given_sc(s,c,shat,pool=60):
    """Histogram the joint mass from p_nab_given_c by the estimate stored in shat.

    shat is raveled alongside the (na, nb, nab) mass table, and masses are
    summed into unit-width bins centred on 0..pool. If the table accounts for
    less than 0.99 of the probability, the Monte Carlo variant is used instead.
    Returns the result object of scipy.stats.binned_statistic.
    """
    joint = p_nab_given_c(s, c, pool=pool)
    if np.sum(joint)<0.99:
        print('Swapping to Monte Carlo')
        return p_shat_given_sc_montecarlo(s, c, shat, pool=pool)
    edges = np.arange(pool+2)-0.5
    return binsta(np.ravel(shat), np.ravel(joint), statistic='sum', bins=edges)

def p_shat_given_sc_montecarlo(s,c,shat,pool=60,n_mc=int(1e5)):
    """Monte Carlo counterpart of p_shat_given_sc.

    Simulates n_mc PCR pairs with pcr_sample, tallies the observed
    (na, nb, nab) triples, converts the tallies to frequencies, and sums them
    into unit-width bins centred on 0..pool according to shat.
    Returns the result object of scipy.stats.binned_statistic.
    """
    tallies = np.zeros([pool+1,pool+1,pool+1])
    for _ in range(n_mc):
        nab, na, nb = pcr_sample(c, s)
        tallies[na, nb, nab] += 1
    frequencies = tallies/n_mc
    edges = np.arange(pool+2)-0.5
    return binsta(np.ravel(shat), np.ravel(frequencies), statistic='sum', bins=edges)

def compute_all_estimates(pool=60):
    """Tabulate e_overlap for every feasible (na, nb, nab) triple.

    Returns an array of shape (pool+1, pool+1, pool+1) with
    table[na, nb, nab] = e_overlap(na, nb, nab) for 1 <= na, nb <= pool and
    0 <= nab <= min(na, nb); every other entry is left at zero.
    """
    table = np.zeros([pool+1,pool+1,pool+1])
    sample_sizes = range(1, pool+1)
    for na in sample_sizes:
        for nb in sample_sizes:
            largest_overlap = min(na, nb)
            for nab in range(largest_overlap+1):
                table[na, nb, nab] = e_overlap(na, nb, nab, pool=pool)
    return table

def p_overlap_unequal(na,nb,nab,pool_a,pool_b):
    """Return the normalised distribution over the true overlap s = 0..pool_a.

    Variant of p_overlap for pools of different sizes: the first hypergeometric
    step (sa out of na draws) uses pool_a and the second (nab out of nb draws)
    uses pool_b. All indexing is in terms of pool_a, which is assumed to be
    <= pool_b.
    """
    # reference: hypergeom.pmf(outcome, Total, hits, Draws, loc=0)
    support = np.arange(pool_a+1)
    # P(nab | sa) for every sa; independent of s, so evaluated a single time.
    lik_nab_given_sa = hypergeom.pmf(nab, pool_b, support, nb)
    weights = np.array([
        # P(sa | s) over all sa, dotted with P(nab | sa)
        np.dot(hypergeom.pmf(support, pool_a, s, na), lik_nab_given_sa)
        for s in support
    ])
    return weights/np.sum(weights)

def e_overlap_unequal(na,nb,nab,pool_a,pool_b):
    """Return the mean of the distribution over s produced by p_overlap_unequal."""
    # TODO. Code expects that pool_b > pool_a...
    posterior = p_overlap_unequal(na, nb, nab, pool_a, pool_b)
    s_values = np.arange(pool_a+1)
    return np.dot(s_values, posterior)

# Table of overlap estimates for pool=60, read from a previously saved file.
# The two commented-out lines show how the file is produced: compute_all_estimates
# fills shat[na, nb, nab] with e_overlap(na, nb, nab), and the result is saved to disk.
# NOTE(review): the file is presumably cached because recomputing the full table is slow.
# shat = compute_all_estimates(pool=60)
# np.save('shat_60.npy',shat)
shat = np.load('shat_60.npy')

# Simulation settings: `reps` replicates per planted overlap, repertoire a fixed at size 60.
reps=2
pool_a = 60

# One row per plotted case, as (na, nb, pool_b):
#   1A: 40 and 40 samples, both repertoires of size 60
#   1B: 40 and 40 samples, repertoire b of size 120
#   2A: 10 and 40 samples, both repertoires of size 60
#   2B: 10 and 40 samples, repertoire b of size 120
scenarios = [
    (40, 40, 60),
    (40, 40, 120),
    (10, 40, 60),
    (10, 40, 120),
]

# Results indexed as [case, replicate, planted overlap s].
recovered = np.zeros([len(scenarios),reps,pool_a+1])
planted = np.zeros([len(scenarios),reps,pool_a+1])
recovered_pts = np.zeros([len(scenarios),reps,pool_a+1])

for case,(na,nb,pool_b) in enumerate(scenarios):
    for s in np.arange(0,pool_a+1,1):
        for rep in range(reps):
            # simulate an observed overlap for this planted s
            nab = nab_sample_unequal(s,na,nb,pool_a,pool_b)
            planted[case,rep,s] = s
            # model-based estimate of s (posterior mean)
            recovered[case,rep,s] = e_overlap_unequal(na,nb,nab,pool_a,pool_b)
            # point estimate 2*nab/(na+nb) (Sorensen-Dice form), as a fraction
            recovered_pts[case,rep,s] = 2*nab/(na+nb)

# x/y values for the dashed identity (reference) line, and the repertoire size used
# to rescale the fractional point estimate onto the same axis as the overlap s.
d = np.arange(1,60)
pool = 60

# Collapse the (replicate, s) axes into one so each case becomes a flat vector of points.
# The width is derived from the arrays themselves instead of the former hard-coded 122
# (= 2 replicates * 61 overlaps), so changing reps or pool_a no longer breaks this step.
# reshape on these C-ordered arrays yields the same element order as per-case .flatten();
# .copy() keeps the flattened arrays independent of the originals, as before.
n_cases = planted.shape[0]
planted2 = planted.reshape(n_cases, -1).copy()
recovered2 = recovered.reshape(n_cases, -1).copy()
recovered_pts2 = recovered_pts.reshape(n_cases, -1).copy()

# One marker colour per case for the model-based ("BRO") estimates.
colors = ['rgba(129,105,177,0.7)', 'rgba(172,110,191,0.7)','rgba(215,128,193,0.7)', 'rgba(240,133,172,0.7)']

# Dropdown entries, one per case, as (button label, figure title).
case_menu = [
    ("a = 40x60 , b = 40x60",
     "pop. a: 40 samples from total size 60,<br>pop. b: 40 samples from total size 60"),
    ("a = 40x60 , b = 40x120",
     "pop. a: 40 samples from total size 60,<br>pop. b: 40 samples from total size 120"),
    ("a = 10x60 , b = 40x60",
     "pop. a: 10 samples from total size 60,<br>pop. b: 40 samples from total size 60"),
    ("a = 10x60 , b = 40x120",
     "pop. a: 10 samples from total size 60,<br>pop. b: 40 samples from total size 120"),
]

fig6 = go.Figure()

# Trace 0: dashed identity line, always visible.
fig6.add_trace(go.Scatter(x = d , y = d, line = dict(width=1, dash='dash'), name = 'Reference', showlegend = False ))

# Traces 1+2*case and 2+2*case: the "BRO" and "S" point clouds of each case.
# Only the first case is shown initially; the dropdown toggles the rest.
for case in range(len(case_menu)):
    shown = (case == 0)
    fig6.add_trace(go.Scatter(x=planted2[case], y=recovered2[case],
                              mode='markers',
                              visible=shown,
                              name='BRO',
                              marker=dict(color=colors[case])))
    fig6.add_trace(go.Scatter(x=planted2[case], y=pool*recovered_pts2[case],
                              mode='markers',
                              visible=shown,
                              name='S',
                              marker_symbol="x",
                              marker=dict(color='rgba(115, 115, 115, 0.7)')))

# Build one dropdown button per case: keep the reference line on and switch on
# exactly that case's pair of traces, updating the title to match.
n_traces = 1 + 2*len(case_menu)
menu_buttons = []
for case,(label,title) in enumerate(case_menu):
    mask = [True] + [False]*(n_traces-1)
    mask[1+2*case] = True
    mask[2+2*case] = True
    menu_buttons.append(dict(label=label,
                             method="update",
                             args=[{"visible": mask},
                                   {"title": title}]))

fig6.update_layout(
    updatemenus=[
        dict(
            active=0,
            buttons=menu_buttons,
            x=0.4,
            y=1.0
        )
    ])
fig6.update_layout(
    plot_bgcolor='rgb(255,255,255)',
    xaxis_title="True overlap s",
    title=case_menu[0][1],
    yaxis_title='Estimated overlap s',
    legend=dict(x=.1, y=.65, bordercolor='Black', borderwidth=2))

fig6.update_xaxes(ticks = 'outside', showline=True, linecolor='black')
fig6.update_yaxes(ticks = 'outside', showline=True, linecolor='black')

# Plot figure
plot(fig6, filename = 'plotly_figures/fig6.html', config = config)
display(HTML('plotly_figures/fig6.html'))