Search
Perfusion AIF LV detection

Automated detection of the left ventricle (LV) from arterial input function (AIF) image series for cardiac MR perfusion, with model deployment to the MR scanner

Author: Hui Xue <hui.xue@nih.gov>

#import os
#os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
#os.environ['CUDA_VISIBLE_DEVICES']='1,2'

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.onnx
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision.utils import *

import numpy as np
import collections
import matplotlib.pyplot as plt
from matplotlib import animation, rc
animation.rcParams['animation.writer'] = 'ffmpeg'
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

import scipy
import scipy as sp
from scipy.spatial import ConvexHull
from scipy.ndimage.morphology import binary_fill_holes

from collections import OrderedDict
import time
from tensorboardX import SummaryWriter

from skimage import io, transform
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision

from IPython.display import display, clear_output, HTML, Image

from PIL import Image
import imp
import os
import sys
import math
import time
import random
import shutil
import scipy.misc
from glob import glob
import sklearn
import logging
from tqdm.notebook import tqdm

%matplotlib inline
%load_ext autoreload
%autoreload 2
import training
import models
import utils
import utils.cmr_ml_utils_plotting
def show(img):
    """Plot a channel-first image with matplotlib.

    Args:
        img: object whose ``.numpy()`` returns a 3-D array. The first axis is
            moved to the end before plotting (presumably a ``[C, H, W]`` torch
            tensor such as the output of ``make_grid`` -- confirm with callers).
    """
    pixels = img.numpy()
    print(pixels.shape)
    # imshow expects channel-last data, so rotate the axes [C, H, W] -> [H, W, C].
    channel_last = pixels.transpose(1, 2, 0)
    plt.imshow(channel_last, interpolation='nearest')
## Load image data
# Root folder handed to PerfAIFDataset, which scans <img_dir>/<location>/<case>/
# for the per-case .npy files.
img_dir = './data'
import scipy.io

class PerfAIFDataset(Dataset):
    """Perfusion AIF dataset, loaded completely into memory at construction.

    Directory layout read by the loader::

        <img_dir>/<location>/<case>/aif_scc.npy          image series [RO, E1, N]
        <img_dir>/<location>/<case>/aif_masks_final.npy  label mask [RO, E1]
        <img_dir>/<location>/<case>/aif_masks.npy        used when the file
                                                         above cannot be loaded

    Every accepted case is brought to ``[min_reps, EXPECTED_RO, W]``: the
    series is truncated or padded along N (padding repeats the last frame),
    centre-cropped or zero-padded along E1, scaled by its maximum and moved to
    channel-first order. Two masks of shape ``[1, RO, W]`` are kept per case:
    the label mask as loaded, and a binary mask of the pixels labelled 1.
    """

    # Readout size every case must have; cases with another RO are skipped.
    EXPECTED_RO = 64

    def __init__(self, img_dir, which_mask='LV_RV', min_reps=64, W=48, transform=None):
        """
        Args:
            img_dir (string): Directory with all the images.
            which_mask (string): 'lv' or 'lv_rv' (case-insensitive); selects
                the mask returned by ``__getitem__``.
            min_reps (int): number of frames every series is padded or
                truncated to.
            W (int): E1 size every series and mask is cropped or padded to.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.which_mask = which_mask
        self.min_reps = min_reps
        self.W = W

        # Only sub-directories of img_dir are treated as locations.
        location_dirs = [loc for loc in os.listdir(self.img_dir)
                         if os.path.isdir(os.path.join(self.img_dir, loc))]

        num_samples = sum(len(os.listdir(os.path.join(self.img_dir, loc)))
                          for loc in location_dirs)
        print("Found %d cases ... " % num_samples)

        self.initialize_storage()

        t0 = time.time()
        print("Start loading cases ... ")

        total_num_loaded = 0
        for loc in location_dirs:
            total_num_loaded = self.load_one_loc(loc, total_num_loaded, t0)

        print("Total samples loaded %d " % total_num_loaded)

    def initialize_storage(self):
        """Reset the in-memory sample lists (kept index-aligned with each other)."""
        self.aif = []
        self.lv_rv_masks = []
        self.lv_masks = []
        self.names = []

    def load_one_data(self, loc, f_prefix):
        """Load ``<img_dir>/<loc>/<f_prefix>.npy`` and return the array."""
        f_name = f_prefix + '.npy'
        data = np.load(os.path.join(self.img_dir, loc, f_name))
        return data

    def load_one_loc(self, loc, total_num_loaded, t0):
        """Load every case found under one location directory.

        Args:
            loc (string): location directory name, relative to ``img_dir``.
            total_num_loaded (int): running count of samples loaded so far.
            t0 (float): ``time.time()`` at the start of loading, used only for
                the elapsed-time display.

        Returns:
            int: ``total_num_loaded`` plus the number of cases accepted here.
        """
        t1 = time.time()

        case_names = os.listdir(os.path.join(self.img_dir, loc))
        num_samples = len(case_names)

        print('---> Start loading ', loc)
        tq = tqdm(total=(num_samples), file=sys.stdout)

        for ii, n in enumerate(case_names):
            name = os.path.join(loc, n)
            tq.set_description('loading {}, total {}'.format(ii, num_samples))

            # `except Exception` (not a bare except) keeps the best-effort
            # skipping of unreadable cases while still letting Ctrl-C abort
            # a long load.
            try:
                Gd = self.load_one_data(name, 'aif_scc')
                try:
                    lv_rv = self.load_one_data(name, 'aif_masks_final')
                except Exception:
                    lv_rv = self.load_one_data(name, 'aif_masks')
            except Exception:
                print('------> Failed to load %d out of %d, %s' % (ii, num_samples, n))
                continue

            if Gd.ndim != 3:
                print('--> incorrect Gd shape : ', name)
                continue

            RO, E1, N = Gd.shape

            if N < self.min_reps:
                # Pad the series by repeating its last frame.
                new_Gd = np.zeros((RO, E1, self.min_reps))
                new_Gd[:, :, 0:N] = Gd
                new_Gd[:, :, N:self.min_reps] = Gd[:, :, N-1:N]
                Gd = new_Gd
            elif N > self.min_reps:
                Gd = Gd[:, :, 0:self.min_reps]

            if E1 > self.W:
                # Centre crop along E1; the mask gets the same crop.
                s = int((E1 - self.W) / 2)
                Gd = Gd[:, s:s+self.W, :]
                lv_rv = lv_rv[:, s:s+self.W]
            elif E1 < self.W:
                # Centre zero-pad along E1; the mask gets the same padding.
                s = int((self.W - E1) / 2)
                new_Gd = np.zeros((RO, self.W, self.min_reps))
                new_Gd[:, s:s+E1, :] = Gd
                Gd = new_Gd

                new_lv_rv = np.zeros((RO, self.W))
                new_lv_rv[:, s:s+E1] = lv_rv
                lv_rv = new_lv_rv

            # Compare against the configured sizes rather than literal 48/64 so
            # that non-default W / min_reps do not reject every case.
            RO, E1, N = Gd.shape
            if RO != self.EXPECTED_RO or E1 != self.W or N != self.min_reps:
                print('--> incorrect Gd shape : ', name)
                continue

            # Scale to a maximum of 1; an all-zero series is left untouched
            # instead of being turned into NaNs.
            peak = np.max(Gd)
            if peak != 0:
                Gd = Gd / peak
            Gd = np.transpose(Gd, (2, 0, 1))

            # Binary mask of the pixels labelled 1 in the label mask.
            lv = np.zeros_like(lv_rv)
            lv[np.where(lv_rv == 1)] = 1

            lv_rv = np.reshape(lv_rv, (1, lv_rv.shape[0], lv_rv.shape[1]))
            lv = np.reshape(lv, (1, lv.shape[0], lv.shape[1]))

            self.aif.append(Gd.astype(np.float32))
            self.lv_rv_masks.append(lv_rv.astype(np.float32))
            self.lv_masks.append(lv.astype(np.float32))
            self.names.append(name)

            total_num_loaded += 1

            t1 = time.time()

            tq.update(1)
            tq.set_postfix(loss='{:.2f}s'.format(t1-t0))

        str_after_loading = '    Finish loading %s --- Total %d samples -- In %.2f seconds' % (loc, num_samples, t1-t0)
        tq.set_postfix_str(str_after_loading)

        tq.close()

        return total_num_loaded

    def __len__(self):
        return len(self.aif)

    def __getitem__(self, idx):
        """Return ``(aif, mask, name)`` for sample ``idx``.

        Raises:
            IndexError: ``idx`` is past the last sample.
            ValueError: ``which_mask`` is neither 'lv' nor 'lv_rv'.
        """
        if idx >= len(self.aif):
            raise IndexError("invalid index")

        # Case-insensitive so that the constructor default 'LV_RV' is usable.
        mask_kind = str(self.which_mask).lower()
        if mask_kind == 'lv_rv':
            sample = (self.aif[idx], self.lv_rv_masks[idx], self.names[idx])
        elif mask_kind == 'lv':
            sample = (self.aif[idx], self.lv_masks[idx], self.names[idx])
        else:
            raise ValueError("which_mask must be 'lv' or 'lv_rv', got %r" % (self.which_mask,))

        if self.transform:
            sample = self.transform(sample)

        return sample

    def __str__(self):
        lines = [
            "Perfusion AIF Dataset",
            "  image root: %s" % self.img_dir,
            "  Number of samples: %d" % len(self.aif),
            "  Number of masks: %d" % len(self.lv_rv_masks),
        ]
        if len(self.aif) > 0:
            lines.append("  image shape: %d %d %d" % self.aif[0].shape)
            lines.append("  myo mask shape: %d %d %d" % self.lv_rv_masks[0].shape)
        return "\n".join(lines) + "\n"
# Load every case under img_dir into memory and report what was found.
perf_aif_dataset = PerfAIFDataset(img_dir)
print("Done")
print(perf_aif_dataset)
perf_aif_dataset.which_mask = 'lv_rv'

# Sanity check: every sample should be a [64, 64, 48] series with a
# [1, 64, 48] mask; report anything that is not.
B = len(perf_aif_dataset)

for n in range(B):
    Gd, masks, names = perf_aif_dataset[n]

    N, RO, E1 = Gd.shape
    if (N, RO, E1) != (64, 64, 48):
        print('--> incorrect Gd shape : ', (n, names, Gd.shape))

    _, RO, E1 = masks.shape
    if (RO, E1) != (64, 48):
        print('--> incorrect masks shape : ', (n, names, masks.shape))

Check that the data loaded properly

# Per-sample L2 norms of the image series (ni) and of its mask (nm).
B = len(perf_aif_dataset)
print(B)
ni, nm = np.zeros(B), np.zeros(B)

for n in np.arange(B):
    images, masks, names = perf_aif_dataset[n]
    ni[n], nm[n] = np.linalg.norm(images), np.linalg.norm(masks)

    # Report cases whose image norm is below 1.
    if ni[n] < 1:
        print(names, ni[n])

# Original plot from the research paper: image norm against mask norm.
plt.figure(figsize=(8, 8))
plt.plot(ni, nm, 'bo')
# Plotly version of the norm scatter plot (image norm on x, mask norm on y).
# x and y given as array_like objects
import plotly.graph_objects as go
import plotly
# `display`/`HTML` are imported from IPython.display; the IPython.core.display
# location for `display` is deprecated.
from IPython.display import display, HTML
import pandas as pd
import plotly.express as px
fig = px.scatter(x=ni, y=nm)

plotly.offline.plot(fig, filename = 'figure_1.html')
# HTML('figure_1.html') would render the literal text "figure_1.html";
# `filename=` makes IPython read and embed the file that was just written.
display(HTML(filename='figure_1.html'))
print(perf_aif_dataset)
# Hold out the last sample for validation; everything before it is training data.
NUM_TRAIN = len(perf_aif_dataset) - 1
batch_size = 2

# Both loaders share the dataset and differ only in the index range they sample.
loader_for_train = DataLoader(
    perf_aif_dataset,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
    batch_size=batch_size,
)

loader_for_val = DataLoader(
    perf_aif_dataset,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, len(perf_aif_dataset))),
    batch_size=batch_size,
)
# Pull one training batch and look at it.
iter_train = iter(loader_for_train)
print(iter_train)
# Use the built-in next(): DataLoader iterators follow the iterator protocol
# (__next__), and the Python-2 style `.next()` alias is not available on
# current PyTorch releases.
images, masks, names = next(iter_train)

print(images.shape)
print(masks.shape)

ia = np.transpose(images, (2, 3, 1, 0))
print(ia.shape)

#Original Plot from research paper
# One row of panels per case, showing every 4th frame of the series.
for n in range(masks.shape[0]):
    a = utils.cmr_ml_utils_plotting.plot_image_array(np.transpose(np.squeeze(images[n, 0:-1:4,:,:]), (1, 2, 0)), columns=16, figsize=[32,8])

# Masks of the whole batch, one panel per case.
a = utils.cmr_ml_utils_plotting.plot_image_array(np.transpose(np.squeeze(masks), (1, 2, 0)), columns=1, figsize=[16,16])

Images

#Plotly
# Slider view of the first case in the batch: one heatmap per sub-sampled frame.
import plotly.graph_objects as go
import numpy as np

# Every 4th frame of case 0; the first 15 of these are shown.
frames = np.squeeze(images[0, 0:-1:4,:,:])

# Create figure
fig = go.Figure()

# Add traces, one for each slider step. Only the first trace starts visible so
# the initial view agrees with the slider's initial position (active=0);
# otherwise all heatmaps would be drawn on top of each other.
for step in range(15):
    fig.add_trace(
        go.Heatmap(z=frames[step], colorscale="Gray", visible=(step == 0))
    )

# Create and add slider: step i shows trace i and hides all the others.
steps = []
for i in range(len(fig.data)):
    visibility = [False] * len(fig.data)
    visibility[i] = True
    steps.append(dict(
        method="update",
        args=[{"visible": visibility},
              {"title": "Slider switched to image: " + str(i)}],  # layout attribute
    ))

sliders = [dict(
    active=0,
    currentvalue={"prefix": "Image: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    sliders=sliders,
    width=500,
    height=600,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0)
)

# Kept from the original cell; it only affects 3D scenes.
fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Add dropdowns: one restyles the colorscale, the other toggles reversescale.
button_layer_1_height = 1.08
colorscale_buttons = [
    dict(args=["colorscale", scale], label=scale, method="restyle")
    for scale in ("Gray", "Cividis", "Blues", "Viridis")
]
reversescale_buttons = [
    dict(args=["reversescale", flag], label=str(flag), method="restyle")
    for flag in (False, True)
]
fig.update_layout(
    updatemenus=[
        dict(
            buttons=colorscale_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.3,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=reversescale_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.8,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
    ]
)

# Text labels for the two dropdowns.
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0.1, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.8, xref="paper", y=1.1,
                             yref="paper", showarrow=False)
    ])

plotly.offline.plot(fig, filename = 'figure_2.html')
# `filename=` embeds the written file; a positional string would be rendered
# as literal HTML text.
display(HTML(filename='figure_2.html'))
#Plotly
# Slider view of the second case in the batch: one heatmap per sub-sampled frame.
import plotly.graph_objects as go
import numpy as np

# Every 4th frame of case 1; the first 15 of these are shown.
frames = np.squeeze(images[1, 0:-1:4,:,:])

# Create figure
fig = go.Figure()

# Add traces, one for each slider step. Only the first trace starts visible so
# the initial view agrees with the slider's initial position (active=0);
# otherwise all heatmaps would be drawn on top of each other.
for step in range(15):
    fig.add_trace(
        go.Heatmap(z=frames[step], colorscale="Gray", visible=(step == 0))
    )

# Create and add slider: step i shows trace i and hides all the others.
steps = []
for i in range(len(fig.data)):
    visibility = [False] * len(fig.data)
    visibility[i] = True
    steps.append(dict(
        method="update",
        args=[{"visible": visibility},
              {"title": "Slider switched to image: " + str(i)}],  # layout attribute
    ))

sliders = [dict(
    active=0,
    currentvalue={"prefix": "Image: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    sliders=sliders,
    width=500,
    height=600,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0)
)

# Kept from the original cell; it only affects 3D scenes.
fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Add dropdowns: one restyles the colorscale, the other toggles reversescale.
button_layer_1_height = 1.08
colorscale_buttons = [
    dict(args=["colorscale", scale], label=scale, method="restyle")
    for scale in ("Gray", "Cividis", "Blues", "Viridis")
]
reversescale_buttons = [
    dict(args=["reversescale", flag], label=str(flag), method="restyle")
    for flag in (False, True)
]
fig.update_layout(
    updatemenus=[
        dict(
            buttons=colorscale_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.3,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=reversescale_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.8,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
    ]
)

# Text labels for the two dropdowns.
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0.1, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.8, xref="paper", y=1.1,
                             yref="paper", showarrow=False)
    ])

plotly.offline.plot(fig, filename = 'figure_3.html')
# `filename=` embeds the written file; a positional string would be rendered
# as literal HTML text.
display(HTML(filename='figure_3.html'))

Masks

# Plotly
# Heatmaps of the two masks in the current batch, with a dropdown to switch
# between them.
import plotly.graph_objects as go

import pandas as pd

mask_stack = np.squeeze(masks)

# Create figure
fig = go.Figure()

# One heatmap trace per mask. Note the trace named 'Figure 1 - Mask' shows
# batch entry 1 and 'Figure 2 - Mask' shows batch entry 0, as in the original.
fig.add_trace(
     go.Heatmap(z=mask_stack[1],  name = 'Figure 1 - Mask', colorscale="Gray")
)

fig.add_trace(
     go.Heatmap(z=mask_stack[0],  name = 'Figure 2 - Mask', colorscale="Gray")
)

# Update plot sizing
fig.update_layout(
    width=600,
    height=500,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

# Kept from the original cell; it only affects 3D scenes.
fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Add dropdowns: colorscale, reversescale, and which mask trace is visible.
button_layer_1_height = 1.08
colorscale_buttons = [
    dict(args=["colorscale", scale], label=scale, method="restyle")
    for scale in ("Gray", "Cividis", "Blues", "Viridis")
]
reversescale_buttons = [
    dict(args=["reversescale", flag], label=str(flag), method="restyle")
    for flag in (False, True)
]
mask_buttons = [
    dict(label=label,
         method='update',
         args=[{'visible': visible},
               {'title': label,
                'showlegend': True}])
    for label, visible in (('Figure 1 - Mask', [True, False]),
                           ('Figure 2 - Mask', [False, True]))
]
fig.update_layout(
    updatemenus=[
        dict(
            buttons=colorscale_buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=reversescale_buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.37,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            active=1,
            buttons=mask_buttons,
        )
    ]
)

# Text labels for the first two dropdowns.
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=-0.02, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
                             yref="paper", showarrow=False)
    ])


plotly.offline.plot(fig, filename = 'figure_4.html')
# `filename=` embeds the written file; a positional string would be rendered
# as literal HTML text.
display(HTML(filename='figure_4.html'))
## Select the compute device
USE_GPU = True

# CUDA is used only when it is both requested and available; otherwise the CPU.
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')

print('using device:', device)

# Report when more than one GPU is present.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
## Data augmentation
class RandomFlip1stDim(object):
    """Random flip along the first spatial axis of a sample.

    A sample is a tuple ``(image, mask, name)`` whose image and mask are 3-D
    arrays; the flip reverses axis 1 of both (the axis that follows the
    leading channel/frame axis).

    Args:
        p (float): probability of the sample being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (tuple): ``(image, mask, name)`` sample to be flipped.
        Returns:
            tuple: a new ``(image, mask, name)`` with contiguous flipped
            copies, or ``img`` itself when no flip is drawn.
        """
        if not random.random() < self.p:
            return img

        # Reversing axis 1 directly equals the transpose / flipud / transpose
        # round trip on a 3-D array.
        flipped_image = np.asarray(img[0])[:, ::-1, :].copy()
        flipped_mask = np.asarray(img[1])[:, ::-1, :].copy()
        return (flipped_image, flipped_mask, img[2])

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)
    
class RandomFlip2ndDim(object):
    """Random flip along the second spatial axis of a sample.

    A sample is a tuple ``(image, mask, name)`` whose image and mask are 3-D
    arrays; the flip reverses axis 2 (the last axis) of both.

    Args:
        p (float): probability of the sample being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (tuple): ``(image, mask, name)`` sample to be flipped.
        Returns:
            tuple: a new ``(image, mask, name)`` with contiguous flipped
            copies, or ``img`` itself when no flip is drawn.
        """
        if not random.random() < self.p:
            return img

        # Reversing the last axis directly equals the transpose / fliplr /
        # transpose round trip on a 3-D array.
        flipped_image = np.asarray(img[0])[:, :, ::-1].copy()
        flipped_mask = np.asarray(img[1])[:, :, ::-1].copy()
        return (flipped_image, flipped_mask, img[2])

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import scipy
from scipy import ndimage

# probs should be a 64x48 torch tensor
def adaptive_thresh_cpu(probs, p_thresh=0.5, p_thresh_max=0.988):
    """Threshold a probability map into a single-blob mask.

    The map is thresholded at ``max(probs) * p_thresh``; while the result has
    more than one connected component the relative threshold is raised in
    steps of 0.01 until it reaches ``p_thresh_max``. If several components
    remain at that point, the largest one is kept.

    Args:
        probs: 2-D CPU torch tensor of probabilities, shape ``[RO, E1]``.
        p_thresh (float): starting threshold, relative to ``max(probs)``.
        p_thresh_max (float): upper limit for the relative threshold. Should
            not be too close to 1.

    Returns:
        torch.Tensor of shape ``[RO, E1]``, always a tensor so callers can use
        tensor methods on every path:
          * float mask when exactly one component was isolated,
          * uint8 mask of the largest component when several remained,
          * float all-zero mask when nothing was found or an error occurred.
    """
    p_thresh_incr = 0.01

    RO = probs.shape[0]
    E1 = probs.shape[1]

    try:
        peak = torch.max(probs)

        # At least one thresholding pass is always made, even when p_thresh
        # already starts at or above p_thresh_max.
        while True:
            mask = (probs > peak * p_thresh).float()
            blobs, number_of_blobs = ndimage.label(mask.numpy())
            p_thresh += p_thresh_incr  # <-- Note this line can lead to float drift.
            if number_of_blobs <= 1 or p_thresh >= p_thresh_max:
                break

        if number_of_blobs == 1:
            return mask

        if number_of_blobs == 0:
            print("adaptive_thresh_cpu, did not find any blobs ... ", file=sys.stderr)
            sys.stderr.flush()
            return torch.zeros(RO, E1)

        ## If we are here then we cannot isolate a singular blob as the LV.
        ## Select the largest blob as the final mask (first one wins on ties).
        blob_sizes = np.bincount(blobs.ravel())[1:]  # drop label 0 (background)
        largest_label = int(np.argmax(blob_sizes)) + 1
        return torch.tensor((blobs == largest_label).astype(int), dtype=torch.uint8)

    except Exception as e:
        print("Error happened in adaptive_thresh_cpu ...", file=sys.stderr)
        print(e, file=sys.stderr)
        sys.stderr.flush()

    return torch.zeros(RO, E1)
def compute_dice_scores(best_model, loader, aif_trainer, binary_seg=False):
    """Compute one LV Dice score per sample produced by ``loader``.

    Args:
        best_model: network to evaluate; it is put into eval mode and moved to
            the GPU when one is available, otherwise to the CPU.
        loader: iterable yielding ``(x, y, names)`` batches.
        aif_trainer: object providing ``x_dtype`` and ``y_dtype``, the dtypes
            the inputs and targets are cast to.
        binary_seg (bool): if True the scores go through a sigmoid and channel
            0 is the LV probability; otherwise a softmax over dim 1 is used and
            channel 1 is the LV probability.

    Returns:
        (dice_scores, cases): two lists of equal length with the Dice score and
        the case name of every sample.

    Cases with a Dice score below 0.1 are printed and plotted for inspection.
    """
    dice_scores = []
    cases = []

    # Use the GPU when present instead of unconditionally calling .cuda(),
    # which fails on CPU-only machines.
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    best_model = best_model.to(run_device)

    best_model.eval()  # set model to evaluation mode

    lv_channel = 0 if binary_seg else 1

    for x, y, names in loader:

        x = x.to(aif_trainer.x_dtype).to(run_device)
        y = y.to(aif_trainer.y_dtype).to(run_device)

        with torch.no_grad():
            scores = best_model(x)

        if binary_seg:
            probs = torch.sigmoid(scores)  # F.sigmoid is deprecated
        else:
            probs = F.softmax(scores, dim=1)

        probs = probs.cpu().detach()
        aif_mask = y.cpu().detach().numpy()

        for n in range(x.shape[0]):
            lv_probs = probs[n, lv_channel, :, :]
            lv_mask = adaptive_thresh_cpu(lv_probs, p_thresh=0.5)

            # Accept either a tensor or a numpy array from the thresholding
            # helper so an empty result cannot crash the evaluation.
            if torch.is_tensor(lv_mask):
                lv_mask = lv_mask.detach().cpu().numpy()
            else:
                lv_mask = np.asarray(lv_mask)

            # Reference LV mask: pixels labelled 1 in the target.
            lv_aif_mask = np.zeros(lv_mask.shape)
            lv_aif_mask[np.where(np.squeeze(aif_mask[n, 0, :, :] == 1))] = 1

            ds = training.dice(lv_aif_mask, lv_mask)

            if ds < 0.1:
                # Show probabilities, reference and prediction for bad cases.
                print(names[n])
                curr_probs = probs[n, :, :, :].numpy()
                curr_probs = np.transpose(curr_probs, (2, 1, 0))
                a = utils.cmr_ml_utils_plotting.plot_image_array(np.squeeze(curr_probs), columns=8, figsize=[16,16])
                a = utils.cmr_ml_utils_plotting.plot_image_array(lv_aif_mask, columns=8, figsize=[16,16])
                a = utils.cmr_ml_utils_plotting.plot_image_array(lv_mask, columns=8, figsize=[16,16])

            dice_scores.append(ds)
            cases.append(names[n])

    return dice_scores, cases
def get_failed_cases(dice_scores, cases, thres=0.5, print_failed=True):
    """Split cases into successes and failures by Dice threshold.

    A case succeeds when its score is ``>= thres``.

    Args:
        dice_scores: sequence of Dice scores.
        cases: sequence of case names, index-aligned with ``dice_scores``.
        thres (float): minimum score counted as a success.
        print_failed (bool): print every failed case and its score.

    Returns:
        (failed_cases, failed_dices, success_rate). ``success_rate`` is the
        fraction of successful cases, or 0.0 when there are no scores at all
        (instead of raising ZeroDivisionError).
    """
    total_samples = len(dice_scores)
    success_samples = 0
    failed_cases = []
    failed_dices = []

    for k, score in enumerate(dice_scores):
        if score >= thres:
            success_samples += 1
            continue

        if print_failed:
            print("case %s, dice %f " % (cases[k], score))

        failed_cases.append(cases[k])
        failed_dices.append(score)

    success_rate = success_samples / total_samples if total_samples else 0.0
    print("Total test samples is ", total_samples)
    print("Success rate is ", success_rate)

    return failed_cases, failed_dices, success_rate
def load_apply_model_multi_class(img_dir, case_name, model, device):
    """Load one case from disk, run the multi-class model on it and plot the result.

    Reads ``aif.npy`` (series, ``[RO, E1, N]``) and ``aif_masks_final.npy``
    from ``<img_dir>/<case_name>``, falling back to ``aif_masks.npy`` when the
    final mask cannot be loaded.

    Args:
        img_dir: root data directory.
        case_name: case folder, relative to ``img_dir``.
        model: network applied to the series; it is put into eval mode.
        device: torch device the input tensor is moved to.

    Returns:
        (probs, lv_mask, aif_mask): class probabilities as a numpy array with
        the class axis last, the thresholded LV mask as a numpy array, and the
        reference mask exactly as loaded (it is not cropped).
    """
    data_dir = os.path.join(img_dir, case_name)

    model_device = device

    print(data_dir)

    Gd = np.load(os.path.join(data_dir, 'aif.npy'))
    RO, E1, N = Gd.shape

    # `except Exception` keeps the fallback but no longer swallows Ctrl-C.
    try:
        aif_mask = np.load(os.path.join(data_dir, 'aif_masks_final.npy'))
    except Exception:
        aif_mask = np.load(os.path.join(data_dir, 'aif_masks.npy'))

    # Keep at most the first 64 frames and centre-crop E1 to 48. The crop is
    # applied only when E1 exceeds 48: for a narrower series the offset would
    # be negative and slice the wrong columns.
    Gd = Gd[:, :, 0:64]
    if E1 > 48:
        s = int((E1 - 48) / 2)
        Gd = Gd[:, s:s+48, :]

    # [RO, E1, N] -> [1, N, RO, E1], scaled to a maximum of 1. Out-of-place
    # division so that integer-typed data does not raise on the in-place cast.
    Gd = np.transpose(Gd, (2, 0, 1))
    Gd = np.reshape(Gd, (1, Gd.shape[0], Gd.shape[1], Gd.shape[2]))
    Gd = Gd / np.max(Gd)

    aif = torch.from_numpy(Gd).float()
    aif = aif.to(device=model_device)
    model.eval()
    with torch.no_grad():
        scores = model(aif)

    probs = F.softmax(scores, dim=1)
    probs = probs.cpu().detach()

    # Channel 1 is used as the LV probability map.
    lv_probs = probs[0, 1, :, :]
    lv_mask = training.adaptive_thresh(lv_probs, device=torch.device('cpu'), p_thresh=0.5)

    lv_mask = lv_mask.cpu().detach().numpy()
    probs = probs.numpy()
    probs = np.squeeze(np.transpose(probs, (2, 3, 1, 0)))

    a = utils.cmr_ml_utils_plotting.plot_image_array(probs, columns=8, figsize=[16,16])
    a = utils.cmr_ml_utils_plotting.plot_image_array(aif_mask, columns=8, figsize=[16,16])
    a = utils.cmr_ml_utils_plotting.plot_image_array(lv_mask, columns=8, figsize=[16,16])

    return probs, lv_mask, aif_mask

Train with multi-class trainer

# Data augmentation: random flips along both in-plane axes, each with p=0.5.
transform = T.Compose([RandomFlip1stDim(0.5), RandomFlip2ndDim(0.5)])
perf_aif_dataset.transform = transform
perf_aif_dataset.which_mask = 'lv_rv'

# Multi-class setup: three classes, with class 1 weighted 5x in class_weights.
num_classes = 3
class_for_accu = [1, 2]
class_weights = np.where(np.arange(num_classes) == 1, 5.0, 1.0)
p_thres = [0.5, 0.5, 0.75]
print(perf_aif_dataset.which_mask)

# Look at the mask of one (possibly flipped) sample.
sample = perf_aif_dataset[1]
print(sample[1].shape)

# Original plot from the research paper
plt.figure()
plt.imshow(np.squeeze(sample[1]))
#Plotly
# Heatmap of the mask of the sample drawn above, with colorscale controls.

# Create figure
fig = go.Figure()

# Add the mask as a single heatmap trace
fig.add_trace(go.Heatmap(z=np.squeeze(sample[1]), colorscale="Gray"))

# Update plot sizing
fig.update_layout(
    width=800,
    height=900,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

# Kept from the original cell; it only affects 3D scenes.
fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Add dropdowns: one restyles the colorscale, the other toggles reversescale.
button_layer_1_height = 1.08
colorscale_buttons = [
    dict(args=["colorscale", scale], label=scale, method="restyle")
    for scale in ("Gray", "Cividis", "Blues", "Viridis")
]
reversescale_buttons = [
    dict(args=["reversescale", flag], label=str(flag), method="restyle")
    for flag in (False, True)
]
fig.update_layout(
    updatemenus=[
        dict(
            buttons=colorscale_buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=reversescale_buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.37,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
    ]
)

# Text labels for the two dropdowns.
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
                             yref="paper", showarrow=False)
    ])


plotly.offline.plot(fig, filename = 'figure_5.html')
# `filename=` embeds the written file; a positional string would be rendered
# as literal HTML text.
display(HTML(filename='figure_5.html'))
# Draw the next training batch. Use the built-in next(): DataLoader iterators
# follow the iterator protocol (__next__), and the Python-2 style `.next()`
# alias is not available on current PyTorch releases.
images, masks, names = next(iter_train)

B, C, RO, E1 = images.shape

print(images.shape)
print(masks.shape)
print(torch.max(images))
print(torch.max(masks))

# Frame 32 of every case, reshaped to [B, 1, RO, E1] so make_grid treats each
# case as a single-channel image.
a = images[:,32,:,:]
a = torch.reshape(a, (B, 1, RO, E1))

plt.figure(figsize=(8, 8))
show(make_grid(a.double(), nrow=8, padding=2, normalize=False, scale_each=True))

plt.figure(figsize=(8, 8))
show(make_grid(masks.double(), nrow=8, padding=2, normalize=True, scale_each=False))

# Float copies of the batch.
X = images.type(torch.FloatTensor)
y = masks.type(torch.FloatTensor)
print(X.shape)
print(y.shape)
def perform_training(hyperpara, perf_aif_dataset, loader_for_train, loader_for_val):
    """Train one LV/RV segmentation network and record its results.

    Relies on names defined elsewhere in the notebook: ``device``,
    ``img_dir``, ``compute_dice_scores`` and ``get_failed_cases``.

    Args:
        hyperpara (dict): must provide 'num_epochs', 'inplanes', 'layers',
            'layers_planes', 'class_weights' (weight applied to class 1 only;
            the other classes keep weight 1) and 'jaccard_weight'.
        perf_aif_dataset: dataset whose ``which_mask`` is switched to
            'lv_rv' and whose ``aif[0].shape`` is unpacked as ``(C, H, W)``.
        loader_for_train: loader given to the trainer for training; also
            used to compute the training-set dice scores.
        loader_for_val: loader given to the trainer for validation; also
            used to compute the validation dice scores.

    Returns:
        dict: the same ``hyperpara`` object, updated in place with the best
        model, the per-epoch records returned by the trainer, and the dice
        scores / failed cases / success rate for both the validation set
        (un-suffixed keys) and the training set (``*_train`` keys).

    Side effects:
        Writes TensorBoard logs, lets the trainer save models under
        "perf_training/", and writes 'perf_aif_lv_rv_val_failed.mat' into
        ``img_dir``.
    """
    num_epochs = hyperpara['num_epochs']
    print_every = 100000

    inplanes = hyperpara['inplanes']
    layers = hyperpara['layers']
    layers_planes = hyperpara['layers_planes']

    class_weights = hyperpara['class_weights']
    jaccard_weight = hyperpara['jaccard_weight']

    print('======================================================')
    print('num_epochs ', num_epochs)
    print('inplanes ', inplanes)
    print('layers ', layers)
    print('layers_planes ', layers_planes)
    print('class_weights ', class_weights)
    print('jaccard_weight ', jaccard_weight)
    print('======================================================')

    # Three-class problem; accuracy is tracked for classes 1 and 2.
    perf_aif_dataset.which_mask = 'lv_rv'
    num_classes = 3
    class_for_accu = [1, 2]
    p_thres = [0.5, 0.5, 0.75]

    print(perf_aif_dataset.aif[0].shape)
    C, H, W = perf_aif_dataset.aif[0].shape

    model = models.GadgetronResUnet18(F0=C,
                              inplanes=inplanes,
                              layers=layers,
                              layers_planes=layers_planes,
                              use_dropout=False,
                              p=0.5,
                              H=H, W=W, C=num_classes,
                              verbose=True)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print("model on multiple GPU ... ")

    # Learning-rate schedule settings.
    patience = 10
    factor = 0.5
    cooldown = 3
    min_lr = 1e-7

    weight_decay = 0
    learning_rate = 1e-3

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)

    # optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)

    # ``factor`` was previously defined but never handed to the scheduler,
    # which therefore silently used its default; pass it explicitly.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)

    # Unit weight for every class except class 1.
    CW = np.ones(num_classes)
    CW[1] = class_weights
    criterion = training.LossMulti(class_weights=CW, jaccard_weight=jaccard_weight)

    log_dir = 'aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
    writer = SummaryWriter(log_dir)

    aif_trainer = training.GadgetronMultiClassSeg_Perf(model,
                                   optimizer,
                                   criterion,
                                   loader_for_train,
                                   loader_for_val,
                                   class_for_accu=class_for_accu,
                                   p_thres=p_thres,
                                   scheduler=scheduler,
                                   epochs=num_epochs,
                                   device=device,
                                   x_dtype=torch.float32,
                                   y_dtype=torch.long,
                                   early_stopping_thres=100,
                                   print_every=print_every,
                                   writer=writer,
                                   model_folder="perf_training/")

    epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)

    # Validation-set evaluation.
    dice_scores, cases = compute_dice_scores(best_model, loader_for_val, aif_trainer)
    failed_cases, failed_dices, success_rate = get_failed_cases(dice_scores, cases, thres=0.5, print_failed=False)
    # The failed-case names get their own key: the dict literal used to list
    # "cases" twice, so the failed cases were overwritten by the full list.
    # "cases" keeps holding the full case list, exactly as the saved file did.
    scipy.io.savemat(
        os.path.join(img_dir, 'perf_aif_lv_rv_val_failed.mat'),
        {
            "failed_cases": failed_cases,
            "dices": failed_dices,
            "dice_scores": dice_scores,
            "cases": cases,
        },
    )

    # Training-set evaluation.
    dice_scores_train, cases_train = compute_dice_scores(best_model, loader_for_train, aif_trainer)
    failed_cases_train, failed_dices_train, success_rate_train = get_failed_cases(dice_scores_train, cases_train, thres=0.5, print_failed=False)

    hyperpara['best_model'] = best_model
    hyperpara['epochs_traning'] = epochs_traning
    hyperpara['epochs_validation'] = epochs_validation
    hyperpara['loss_all'] = loss_all
    hyperpara['epochs_acc_class'] = epochs_acc_class

    hyperpara['dice_scores'] = dice_scores
    hyperpara['cases'] = cases
    hyperpara['failed_cases'] = failed_cases
    hyperpara['failed_dices'] = failed_dices
    hyperpara['success_rate'] = success_rate

    # These used to be filled with the validation values by mistake.
    hyperpara['dice_scores_train'] = dice_scores_train
    hyperpara['cases_train'] = cases_train
    hyperpara['failed_cases_train'] = failed_cases_train
    hyperpara['failed_dices_train'] = failed_dices_train
    hyperpara['success_rate_train'] = success_rate_train

    return hyperpara
# Grid search over the network width/depth settings.
layers_planes = [[96, 128], [128, 128], [128, 256]]
layers = [[2, 3], [3, 3], [3, 4], [4, 4], [4, 5]]
inplanes = [64, 96, 128]

best_success_rate = 0
best_hyperpara = None

# One result dict per trained configuration, in visiting order.
hyperpara_all = []

for a, planes_choice in enumerate(layers_planes):
    for b, layers_choice in enumerate(layers):
        for c, inplanes_choice in enumerate(inplanes):

            print('-----------------------------------------------')
            print(a, b, c)

            hyperpara = {
                'num_epochs': 40,
                'inplanes': inplanes_choice,
                'layers': layers_choice,
                'layers_planes': planes_choice,
                'class_weights': 5.0,
                'jaccard_weight': 0.5,
            }

            k = 12
            batch_size = 256

            # Split the dataset indices into k chunks: the first chunk is the
            # validation set, all remaining chunks together form the training set.
            chunks = chunk(range(len(perf_aif_dataset)), k)
            listified_chunks = list(chunks)

            val_idxs = listified_chunks[0]
            train_idxs = [idx for part in listified_chunks[1:] for idx in part]

            loader_for_train = DataLoader(
                perf_aif_dataset,
                batch_size=batch_size,
                sampler=sampler.SubsetRandomSampler(train_idxs),
            )
            loader_for_val = DataLoader(
                perf_aif_dataset,
                batch_size=batch_size,
                sampler=sampler.SubsetRandomSampler(val_idxs),
            )

            num_train = len(train_idxs)
            print('num_train = %d' % num_train)
            num_val = len(val_idxs)
            print('num_val = %d' % num_val)

            hyperpara = perform_training(hyperpara, perf_aif_dataset, loader_for_train, loader_for_val)

            # print(hyperpara)
            print('success rate - ', hyperpara['success_rate'])

            # Remember the configuration with the highest success rate so far.
            if hyperpara['success_rate'] > best_success_rate:
                best_success_rate = hyperpara['success_rate']
                best_hyperpara = hyperpara

            hyperpara_all.append(hyperpara)
# Multi-class training run with a single, hand-picked network configuration.
# num_classes, class_weights, class_for_accu, p_thres, device and the two
# loaders are expected to be defined by earlier cells.
num_epochs = 50
print_every = 100000

# Network presets. The cell used to assign all three in sequence, so only the
# last one ever took effect; the overridden ones are kept here for reference.
#   resnet:    inplanes=96, layers=[4, 4], layers_planes=[128, 128], growth_rate=8
#   dense net: inplanes=64, layers=[3, 3], layers_planes=[16, 32],   growth_rate=16
# resnet, small (active preset)
inplanes = 96
layers = [2, 3]
layers_planes = [128, 128]
growth_rate = 8

print(perf_aif_dataset.aif[0].shape)
C, H, W = perf_aif_dataset.aif[0].shape

model = models.GadgetronResUnet18(F0=C,
                          inplanes=inplanes,
                          layers=layers,
                          layers_planes=layers_planes,
                          use_dropout=False,
                          p=0.5,
                          H=H, W=W, C=num_classes,
                          verbose=True)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print("model on multiple GPU ... ")

# Learning-rate schedule settings.
patience = 10
factor = 0.5
cooldown = 3
min_lr = 1e-7

weight_decay = 0
learning_rate = 1e-3

optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)

# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)

# ``factor`` was previously defined but never handed to the scheduler, which
# therefore silently used its default; pass it explicitly.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)

criterion = training.LossMulti(class_weights=class_weights, jaccard_weight=0.5)
# criterion = nn.BCEWithLogitsLoss()
# criterion = nn.BCELoss()

log_dir = './aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
writer = SummaryWriter(log_dir)

aif_trainer = training.GadgetronMultiClassSeg_Perf(model,
                                   optimizer,
                                   criterion,
                                   loader_for_train,
                                   loader_for_val,
                                   class_for_accu=class_for_accu,
                                   p_thres=p_thres,
                                   scheduler=scheduler,
                                   epochs=num_epochs,
                                   device=device,
                                   x_dtype=torch.float32,
                                   y_dtype=torch.long,
                                   early_stopping_thres=100,
                                   print_every=print_every,
                                   small_data_mode=False,
                                   writer=writer,
                                   model_folder="aif_training/")

epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)

Loss

# Matplotlib version of the loss curve (original plot from the research paper):
# first 500 rows of loss_all, column 0 on the x axis and column 1 on the y axis.
fig = plt.figure()
plt.plot(loss_all[:500, 0], loss_all[:500, 1])

# Plotly version of the same curve.
fig = go.Figure()
fig.add_trace(go.Scatter(x=loss_all[:500, 0], y=loss_all[:500, 1], name='Loss'))

plotly.offline.plot(fig, filename='figure_6.html')
display(HTML('figure_6.html'))

# Evaluate the best model on the validation loader and print the metrics.
acc, loss, acc_class = aif_trainer.check_validation_test_accuracy(loader_for_val, best_model)
print(acc, loss)
print(acc_class)

Train with binary segmentation

# Data augmentation: random flipping along the first and second dimensions
# (each transform is constructed with 0.5).
transform = torchvision.transforms.Compose([
    RandomFlip1stDim(0.5),
    RandomFlip2ndDim(0.5),
])
perf_aif_dataset.transform = transform

# Switch the dataset to the 'lv' mask: one output class, p_thres of 0.5.
perf_aif_dataset.which_mask = 'lv'
num_classes, p_thres = 1, 0.5
print(perf_aif_dataset.which_mask)

# Fetch one sample and preview its second element.
sample = perf_aif_dataset[1]
print(sample[1].shape)

# Original plot from research paper
plt.figure()
plt.imshow(np.squeeze(sample[1]))
# Plotly

#fig = px.imshow(np.squeeze(sample[1]), binary_string=True)

# Heatmap of the squeezed second element of the sample, gray colorscale by default.
fig = go.Figure()
fig.add_trace(go.Heatmap(z=np.squeeze(sample[1]), colorscale="Gray"))

# Fixed plot size with room at the top for the dropdown menus.
fig.update_layout(
    width=800,
    height=900,
    autosize=False,
    margin={"t": 100, "b": 0, "l": 0, "r": 0},
)

# Scene options carried over unchanged.
# NOTE(review): scene settings normally apply to 3D traces; this figure only
# holds a heatmap, so this call is presumably template residue - confirm.
fig.update_scenes(
    aspectratio={"x": 1, "y": 1, "z": 0.7},
    aspectmode="manual",
)

# Dropdown menus: one restyles the colorscale, the other toggles reversescale.
button_layer_1_height = 1.08
fig.update_layout(
    updatemenus=[
        dict(
            buttons=[
                dict(args=["colorscale", scale_name], label=scale_name, method="restyle")
                for scale_name in ("Gray", "Cividis", "Blues", "Viridis")
            ],
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
        dict(
            buttons=[
                dict(args=["reversescale", flag], label=flag_label, method="restyle")
                for flag, flag_label in ((False, "False"), (True, "True"))
            ],
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.37,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
    ]
)

# Text labels for the two dropdown menus, positioned in paper coordinates.
fig.update_layout(
    annotations=[
        {
            "text": "colorscale",
            "x": 0,
            "xref": "paper",
            "y": 1.06,
            "yref": "paper",
            "align": "left",
            "showarrow": False,
        },
        {
            "text": "Reverse<br>Colorscale",
            "x": 0.25,
            "xref": "paper",
            "y": 1.07,
            "yref": "paper",
            "showarrow": False,
        },
    ]
)

# Write the figure to figure_7.html and hand that name to IPython's HTML display.
plotly.offline.plot(fig, filename='figure_7.html')
display(HTML('figure_7.html'))
# Pull one batch from the training iterator for a visual sanity check.
# The builtin next() is used rather than the iterator's ``.next()`` method:
# ``.next()`` is a Python 2 style alias that newer PyTorch DataLoader
# iterators no longer provide, while next() works with every iterator.
images, masks, names = next(iter_train)

B, C, RO, E1 = images.shape

print(images.shape)
print(masks.shape)
print(torch.max(images))
print(torch.max(masks))

# Keep channel index 32 of every sample and restore a singleton channel axis
# so the result has shape (B, 1, RO, E1) for make_grid.
a = images[:, 32, :, :]
a = torch.reshape(a, (B, 1, RO, E1))

plt.figure(figsize=(16, 16))
show(make_grid(a.double(), nrow=8, padding=2, normalize=False, scale_each=True))

plt.figure(figsize=(16, 16))
show(make_grid(masks.double(), nrow=8, padding=2, normalize=True, scale_each=False))

# Float32 CPU copies of the batch.
X = images.type(torch.FloatTensor)
y = masks.type(torch.FloatTensor)
print(X.shape)
print(y.shape)
# Binary (single-mask) training run. num_classes, p_thres, device and the two
# loaders are expected to be defined by earlier cells.
num_epochs = 50
print_every = 100000

inplanes = 96
layers = [4, 4]
layers_planes = [128, 128]

print(perf_aif_dataset.aif[0].shape)
C, H, W = perf_aif_dataset.aif[0].shape

model = models.GadgetronResUnet18(F0=C,
                          inplanes=inplanes,
                          layers=layers,
                          layers_planes=layers_planes,
                          use_dropout=False,
                          p=0.5,
                          H=H, W=W, C=num_classes,
                          verbose=True)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print("model on multiple GPU ... ")

# Learning-rate schedule settings.
patience = 10
factor = 0.5
cooldown = 3
min_lr = 1e-7

weight_decay = 0
learning_rate = 1e-3

optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)

# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)

# ``factor`` was previously defined but never handed to the scheduler, which
# therefore silently used its default; pass it explicitly.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)

criterion = training.LossBinary(jaccard_weight=0.5)
# criterion = training.LossMulti(class_weights=class_weights, jaccard_weight=0.5)

log_dir = 'aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
writer = SummaryWriter(log_dir)

aif_trainer = training.GadgetronTwoClassSeg_PerfAIF(model,
                                   optimizer,
                                   criterion,
                                   loader_for_train,
                                   loader_for_val,
                                   p_thres=p_thres,
                                   scheduler=scheduler,
                                   epochs=num_epochs,
                                   device=device,
                                   x_dtype=torch.float32,
                                   y_dtype=torch.float32,
                                   early_stopping_thres=100,
                                   print_every=print_every,
                                   small_data_mode=False,
                                   writer=writer,
                                   model_folder="aif_training/")

epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)

Saving the model

# Move the trained model to the CPU. If it is a wrapper exposing the real
# network as ``.module`` (as nn.DataParallel does), unwrap it so the bare
# network is what gets saved.
try:
    best_model_cpu = best_model.cpu().module
except AttributeError:
    # No ``.module`` attribute: the model is not wrapped, keep it as is.
    # Only AttributeError is caught so that any other failure (for example
    # best_model being undefined) is reported instead of being hidden.
    best_model_cpu = best_model.cpu()

print(best_model_cpu)
from datetime import date
from time import gmtime, strftime

# Calendar date in ISO format, printed for reference.
today = date.today().isoformat()
print(today)

# UTC timestamp embedded in the network file name.
moment = strftime("%Y%m%d_%H%M%S", gmtime())
print(moment)

# File name pattern: perf_aif_<which_mask>_network_<timestamp>.pbt
model_file = ''.join([
    './deployment/networks/perf_aif_',
    perf_aif_dataset.which_mask,
    '_network_',
    moment,
    '.pbt',
])
print(model_file)
./deployment/networks/perf_aif_lv_network_20210128_012515.pbt
# Make sure the target directory exists; torch.save does not create missing
# parent directories and would fail on a fresh checkout.
model_dir = os.path.dirname(model_file)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)

# Save the whole model object (not just its state dict), then load it back
# as a round-trip check.
torch.save(best_model_cpu, model_file)
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects a fully pickled module; if this load fails
# there, pass weights_only=False (older releases do not accept that keyword).
model_loaded = torch.load(model_file)
print(model_loaded)