Perfusion AIF LV detection
#import os
#os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
#os.environ['CUDA_VISIBLE_DEVICES']='1,2'
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.onnx
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision.utils import *
import numpy as np
import collections
import matplotlib.pyplot as plt
from matplotlib import animation, rc
animation.rcParams['animation.writer'] = 'ffmpeg'
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
import scipy
import scipy as sp
from scipy.spatial import ConvexHull
from scipy.ndimage.morphology import binary_fill_holes
from collections import OrderedDict
import time
from tensorboardX import SummaryWriter
from skimage import io, transform
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision
from IPython.display import display, clear_output, HTML, Image
from PIL import Image
import imp
import os
import sys
import math
import time
import random
import shutil
import scipy.misc
from glob import glob
import sklearn
import logging
from tqdm.notebook import tqdm
%matplotlib inline
%load_ext autoreload
%autoreload 2
import training
import models
import utils
import utils.cmr_ml_utils_plotting
def show(img):
    """Print the shape of an image tensor and draw it with matplotlib.

    The tensor is assumed to be channels-first (C, H, W), e.g. the output of
    ``make_grid``; it is converted to NumPy and moved to channels-last order
    before ``plt.imshow``. Nearest-neighbour interpolation keeps pixels crisp.
    """
    arr = img.numpy()
    print(arr.shape)
    hwc = arr.transpose(1, 2, 0)
    plt.imshow(hwc, interpolation='nearest')
## Load image data
import scipy.io

# Root folder scanned by PerfAIFDataset: one sub-folder per location, each
# holding one folder per case.
img_dir = "./data"
class PerfAIFDataset(Dataset):
    """Perfusion AIF dataset.

    Cases are read from ``img_dir/<loc>/<case>/``. Each case folder holds
    ``aif_scc.npy`` (the AIF series, indexed as (RO, E1, N)) and either
    ``aif_masks_final.npy`` or, as a fallback, ``aif_masks.npy`` (indexed as (RO, E1)).

    Every series is padded/cropped to ``min_reps`` frames and ``W`` columns,
    divided by its maximum and stored channels-first as float32 (N, RO, W).
    Items are ``(aif, mask, name)`` tuples where ``mask`` has shape (1, RO, W).
    """

    def __init__(self, img_dir, which_mask='LV_RV', min_reps=64, W=48, transform=None, H=64):
        """
        Args:
            img_dir (string): Directory with all the images.
            which_mask (string): 'LV_RV' (full multi-label mask) or 'LV' (binary mask
                of label 1); matched case-insensitively.
            min_reps (int): number of time frames each series is padded/cropped to.
            W (int): number of E1 columns each series is centre-cropped/zero-padded to.
            transform (callable, optional): Optional transform applied to each
                ``(aif, mask, name)`` sample.
            H (int): required RO size; cases with a different RO are skipped.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.which_mask = which_mask
        self.min_reps = min_reps
        self.W = W
        self.H = H

        # Every sub-directory of img_dir is a location containing one entry per case.
        locations = [loc for loc in os.listdir(self.img_dir)
                     if os.path.isdir(os.path.join(self.img_dir, loc))]
        num_samples = sum(len(os.listdir(os.path.join(self.img_dir, loc))) for loc in locations)
        print("Found %d cases ... " % num_samples)

        self.initialize_storage()
        t0 = time.time()
        print("Start loading cases ... ")
        total_num_loaded = 0
        for loc in locations:
            total_num_loaded = self.load_one_loc(loc, total_num_loaded, t0)
        print("Total samples loaded %d " % total_num_loaded)

    def initialize_storage(self):
        """Reset the parallel per-case lists (series, both mask kinds, names)."""
        self.aif = []
        self.lv_rv_masks = []
        self.lv_masks = []
        self.names = []

    def load_one_data(self, loc, f_prefix):
        """Load ``img_dir/loc/f_prefix.npy``."""
        f_name = f_prefix + '.npy'
        data = np.load(os.path.join(self.img_dir, loc, f_name))
        return data

    def load_one_loc(self, loc, total_num_loaded, t0):
        """Load every case under ``img_dir/loc`` and return the updated running count."""
        t1 = time.time()
        case_dirs = os.listdir(os.path.join(self.img_dir, loc))
        num_samples = len(case_dirs)
        print('---> Start loading ', loc)
        tq = tqdm(total=(num_samples), file=sys.stdout)
        for ii, n in enumerate(case_dirs):
            name = os.path.join(loc, n)
            tq.set_description('loading {}, total {}'.format(ii, num_samples))
            # `except Exception` instead of a bare except so Ctrl-C / SystemExit still stop loading.
            try:
                Gd = self.load_one_data(name, 'aif_scc')
                try:
                    lv_rv = self.load_one_data(name, 'aif_masks_final')
                except Exception:
                    # Fall back to the non-final masks when the final ones are missing/unreadable.
                    lv_rv = self.load_one_data(name, 'aif_masks')
            except Exception:
                print('------> Failed to load %d out of %d, %s' % (ii, num_samples, n))
                continue

            RO, E1, N = Gd.shape
            # Time axis: pad by repeating the last frame, or crop, to exactly min_reps frames.
            if N < self.min_reps:
                new_Gd = np.zeros((RO, E1, self.min_reps))
                new_Gd[:, :, 0:N] = Gd
                f = Gd[:, :, N - 1]
                new_Gd[:, :, N:self.min_reps] = np.dstack([f] * (self.min_reps - N))
                Gd = new_Gd
            if N > self.min_reps:
                Gd = Gd[:, :, 0:self.min_reps]
            # E1 axis: centre-crop or zero-pad series and mask to W columns.
            if E1 > self.W:
                s = int((E1 - self.W) / 2)
                Gd = Gd[:, s:s + self.W, :]
                lv_rv = lv_rv[:, s:s + self.W]
            if E1 < self.W:
                s = int((self.W - E1) / 2)
                new_Gd = np.zeros((RO, self.W, self.min_reps))
                new_Gd[:, s:s + E1, :] = Gd
                Gd = new_Gd
                new_lv_rv = np.zeros((RO, self.W))
                new_lv_rv[:, s:s + E1] = lv_rv
                lv_rv = new_lv_rv

            # Check against the configured sizes (previously hard-coded to 64/48/64,
            # which rejected every case whenever min_reps or W differed from the defaults).
            RO, E1, N = Gd.shape
            if RO != self.H or E1 != self.W or N != self.min_reps:
                print('--> incorrect Gd shape : ', name)
                continue

            # Scale to a maximum of 1; an all-zero series is kept as-is instead of becoming NaN.
            peak = np.max(Gd)
            if peak != 0:
                Gd = Gd / peak
            Gd = np.transpose(Gd, (2, 0, 1))

            # Binary LV mask = pixels labelled 1 in the multi-label mask.
            lv = np.zeros_like(lv_rv)
            lv[np.where(lv_rv == 1)] = 1
            lv_rv = np.reshape(lv_rv, (1, lv_rv.shape[0], lv_rv.shape[1]))
            lv = np.reshape(lv, (1, lv.shape[0], lv.shape[1]))

            self.aif.append(Gd.astype(np.float32))
            self.lv_rv_masks.append(lv_rv.astype(np.float32))
            self.lv_masks.append(lv.astype(np.float32))
            self.names.append(name)
            total_num_loaded += 1

            t1 = time.time()
            tq.update(1)
            tq.set_postfix(loss='{:.2f}s'.format(t1 - t0))

        str_after_loading = ' Finish loading %s --- Total %d samples -- In %.2f seconds' % (loc, num_samples, t1 - t0)
        tq.set_postfix_str(str_after_loading)
        tq.close()
        return total_num_loaded

    def __len__(self):
        return len(self.aif)

    def __getitem__(self, idx):
        """Return ``(aif, mask, name)``; the mask kind follows ``self.which_mask``.

        Raises:
            IndexError: ``idx`` is past the end (previously a TypeError from raising a str).
            ValueError: ``which_mask`` is neither 'lv_rv' nor 'lv' (any case). The default
                'LV_RV' used to fall through and raise UnboundLocalError.
        """
        if idx >= len(self.aif):
            raise IndexError("invalid index")
        mask_kind = str(self.which_mask).lower()
        if mask_kind == 'lv_rv':
            masks = self.lv_rv_masks
        elif mask_kind == 'lv':
            masks = self.lv_masks
        else:
            raise ValueError("which_mask must be 'lv_rv' or 'lv', got %r" % (self.which_mask,))
        sample = (self.aif[idx], masks[idx], self.names[idx])
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __str__(self):
        text = "Perfusion AIF Dataset\n"
        text += " image root: %s" % self.img_dir + "\n"
        text += " Number of samples: %d" % len(self.aif) + "\n"
        text += " Number of masks: %d" % len(self.lv_rv_masks) + "\n"
        if len(self.aif) > 0:
            text += " image shape: %d %d %d" % self.aif[0].shape + "\n"
            text += " myo mask shape: %d %d %d" % self.lv_rv_masks[0].shape + "\n"
        return text
# Load every case from disk and show a summary.
perf_aif_dataset = PerfAIFDataset(img_dir)
print("Done")
print(perf_aif_dataset)

# Sanity check: each AIF series should be (frames, RO, E1) = (64, 64, 48)
# and each mask (1, 64, 48).
perf_aif_dataset.which_mask = 'lv_rv'
B = len(perf_aif_dataset)
for n in range(B):
    Gd, masks, names = perf_aif_dataset[n]
    N, RO, E1 = Gd.shape
    if (N, RO, E1) != (64, 64, 48):
        print('--> incorrect Gd shape : ', (n, names, Gd.shape))
    _, RO, E1 = masks.shape
    if (RO, E1) != (64, 48):
        print('--> incorrect masks shape : ', (n, names, masks.shape))
# Frobenius norm of every AIF series (ni) and mask (nm); a series whose
# norm is below 1 is printed as suspicious.
B = len(perf_aif_dataset)
print(B)
ni = np.zeros(B)
nm = np.zeros(B)
for n in np.arange(B):
    images, masks, names = perf_aif_dataset[n]
    ni[n], nm[n] = np.linalg.norm(images), np.linalg.norm(masks)
    if ni[n] < 1:
        print(names, ni[n])
#Original Plot from research paper
plt.figure(figsize=(8, 8))
plt.plot(ni, nm, 'bo')
# Plotly
# Interactive version of the norm scatter plot (image norm vs. mask norm),
# written to figure_1.html and shown inline.
import plotly
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from IPython.core.display import display, HTML

fig = px.scatter(x=ni, y=nm)
plotly.offline.plot(fig, filename='figure_1.html')
display(HTML('figure_1.html'))
print(perf_aif_dataset)
# Quick look-loaders: all cases except the last one for training, the last one for validation.
NUM_TRAIN = len(perf_aif_dataset) - 1
batch_size = 2
loader_for_train = DataLoader(perf_aif_dataset, batch_size=batch_size,
                              sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))
loader_for_val = DataLoader(perf_aif_dataset, batch_size=batch_size,
                            sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, len(perf_aif_dataset))))
iter_train = iter(loader_for_train)
print(iter_train)
# Use the builtin next(): recent PyTorch DataLoader iterators no longer provide the
# Python-2 style `.next()` method (AttributeError).
images, masks, names = next(iter_train)
print(images.shape)
print(masks.shape)
ia = np.transpose(images, (2, 3, 1, 0))
print(ia.shape)
#Original Plot from research paper
# Every 4th time frame of each sample in the batch, 16 frames per row.
for n in range(masks.shape[0]):
    a = utils.cmr_ml_utils_plotting.plot_image_array(
        np.transpose(np.squeeze(images[n, 0:-1:4, :, :]), (1, 2, 0)), columns=16, figsize=[32, 8])
# All masks of the batch, plotted once (it used to be re-plotted for every sample).
# masks[:, 0] keeps the batch axis even when the batch holds a single sample,
# where np.squeeze would drop it and break the 3-axis transpose.
a = utils.cmr_ml_utils_plotting.plot_image_array(
    np.transpose(masks[:, 0], (1, 2, 0)), columns=1, figsize=[16, 16])
#Plotly
# Interactive viewer for every 4th time frame of sample 0 of the batch: one heat-map
# trace per frame, a slider to step through them and dropdowns to change / reverse
# the colour scale. Written to figure_2.html and shown inline.
import plotly.graph_objects as go
import numpy as np

# Frames 0, 4, 8, ... of sample 0 -> (n_frames, RO, E1). All frames are shown now;
# the previous hard-coded range(15) silently dropped the last one.
frames = np.asarray(images[0, 0:-1:4, :, :])
num_frames = frames.shape[0]

fig = go.Figure()
for step in range(num_frames):
    # Only the first frame starts visible so the initial view matches the slider's
    # active position (previously every trace was visible and the last one was on top).
    fig.add_trace(go.Heatmap(z=frames[step], colorscale="Gray", visible=(step == 0)))

# Slider: step i shows only trace i.
steps = []
for i in range(num_frames):
    visible = [False] * num_frames
    visible[i] = True
    steps.append(dict(
        method="update",
        args=[{"visible": visible},
              {"title": "Slider switched to image: " + str(i)}],  # layout attribute
    ))
sliders = [dict(
    active=0,
    currentvalue={"prefix": "Image: "},
    pad={"t": 50},
    steps=steps,
)]
fig.update_layout(sliders=sliders)
fig.update_layout(
    width=500,
    height=600,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

# Dropdowns: colour scale and reversed colour scale, applied to all traces.
button_layer_1_height = 1.08
fig.update_layout(
    updatemenus=[
        dict(
            buttons=[dict(args=["colorscale", cs], label=cs, method="restyle")
                     for cs in ("Gray", "Cividis", "Blues", "Viridis")],
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.3,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
        dict(
            buttons=[dict(args=["reversescale", flag], label=str(flag), method="restyle")
                     for flag in (False, True)],
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.8,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
    ]
)
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0.1, xref="paper", y=1.06, yref="paper",
             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.8, xref="paper", y=1.1,
             yref="paper", showarrow=False),
    ])
plotly.offline.plot(fig, filename='figure_2.html')
display(HTML('figure_2.html'))
#Plotly
# Interactive viewer for every 4th time frame of sample 1 of the batch: one heat-map
# trace per frame, a slider to step through them and dropdowns to change / reverse
# the colour scale. Written to figure_3.html and shown inline.
import plotly.graph_objects as go
import numpy as np

# Frames 0, 4, 8, ... of sample 1 -> (n_frames, RO, E1). All frames are shown now;
# the previous hard-coded range(15) silently dropped the last one.
frames = np.asarray(images[1, 0:-1:4, :, :])
num_frames = frames.shape[0]

fig = go.Figure()
for step in range(num_frames):
    # Only the first frame starts visible so the initial view matches the slider's
    # active position (previously every trace was visible and the last one was on top).
    fig.add_trace(go.Heatmap(z=frames[step], colorscale="Gray", visible=(step == 0)))

# Slider: step i shows only trace i.
steps = []
for i in range(num_frames):
    visible = [False] * num_frames
    visible[i] = True
    steps.append(dict(
        method="update",
        args=[{"visible": visible},
              {"title": "Slider switched to image: " + str(i)}],  # layout attribute
    ))
sliders = [dict(
    active=0,
    currentvalue={"prefix": "Image: "},
    pad={"t": 50},
    steps=steps,
)]
fig.update_layout(sliders=sliders)
fig.update_layout(
    width=500,
    height=600,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

# Dropdowns: colour scale and reversed colour scale, applied to all traces.
button_layer_1_height = 1.08
fig.update_layout(
    updatemenus=[
        dict(
            buttons=[dict(args=["colorscale", cs], label=cs, method="restyle")
                     for cs in ("Gray", "Cividis", "Blues", "Viridis")],
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.3,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
        dict(
            buttons=[dict(args=["reversescale", flag], label=str(flag), method="restyle")
                     for flag in (False, True)],
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.8,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
    ]
)
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0.1, xref="paper", y=1.06, yref="paper",
             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.8, xref="paper", y=1.1,
             yref="paper", showarrow=False),
    ])
plotly.offline.plot(fig, filename='figure_3.html')
display(HTML('figure_3.html'))
# Plotly
# The two masks of the current batch as heat-maps, with a dropdown to switch between
# them and dropdowns to change / reverse the colour scale. Written to figure_4.html.
import plotly.graph_objects as go
import pandas as pd

mask_frames = np.asarray(masks[:, 0])  # (batch, RO, E1)

fig = go.Figure()
# Trace 0 ('Figure 1 - Mask') holds batch index 1 and trace 1 ('Figure 2 - Mask') holds
# batch index 0, as before. The mask-selection dropdown below starts with button 1
# active, so only trace 1 is visible initially; previously both traces were visible
# and overlapped, contradicting the active button.
fig.add_trace(
    go.Heatmap(z=mask_frames[1], name='Figure 1 - Mask', colorscale="Gray", visible=False)
)
fig.add_trace(
    go.Heatmap(z=mask_frames[0], name='Figure 2 - Mask', colorscale="Gray", visible=True)
)

fig.update_layout(
    width=600,
    height=500,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

button_layer_1_height = 1.08
fig.update_layout(
    updatemenus=[
        dict(
            buttons=[dict(args=["colorscale", cs], label=cs, method="restyle")
                     for cs in ("Gray", "Cividis", "Blues", "Viridis")],
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
        dict(
            buttons=[dict(args=["reversescale", flag], label=str(flag), method="restyle")
                     for flag in (False, True)],
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.37,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top",
        ),
        dict(
            active=1,
            buttons=[
                dict(label='Figure 1 - Mask',
                     method='update',
                     args=[{'visible': [True, False]},
                           {'title': 'Figure 1 - Mask',
                            'showlegend': True}]),
                dict(label='Figure 2 - Mask',
                     method='update',
                     args=[{'visible': [False, True]},
                           {'title': 'Figure 2 - Mask',
                            'showlegend': True}]),
            ],
        ),
    ]
)
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=-0.02, xref="paper", y=1.06, yref="paper",
             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
             yref="paper", showarrow=False),
    ])
plotly.offline.plot(fig, filename='figure_4.html')
display(HTML('figure_4.html'))
## Construct data loaders
# Compute device: CUDA when requested and present, otherwise the CPU.
USE_GPU = True
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')
print('using device:', device)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
## Data augmentation
class RandomFlip1stDim(object):
    """Randomly mirror an ``(image, mask, name)`` sample along its RO axis.

    Image and mask are laid out ``[N, RO, E1]``; both are flipped together
    along axis 1 and the name is passed through untouched.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img: tuple ``(image, mask, name)``; image and mask are [N RO E1] arrays.
        Returns:
            With probability ``p`` a new tuple whose arrays are mirrored along RO
            (as contiguous copies); otherwise ``img`` itself.
        """
        if not random.random() < self.p:
            return img
        image, mask, name = img[0], img[1], img[2]
        return (image[:, ::-1, :].copy(), mask[:, ::-1, :].copy(), name)

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)
class RandomFlip2ndDim(object):
    """Randomly mirror an ``(image, mask, name)`` sample along its E1 axis.

    Image and mask are laid out ``[N, RO, E1]``; both are flipped together
    along axis 2 and the name is passed through untouched.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img: tuple ``(image, mask, name)``; image and mask are [N RO E1] arrays.
        Returns:
            With probability ``p`` a new tuple whose arrays are mirrored along E1
            (as contiguous copies); otherwise ``img`` itself.
        """
        if not random.random() < self.p:
            return img
        image, mask, name = img[0], img[1], img[2]
        return (image[:, :, ::-1].copy(), mask[:, :, ::-1].copy(), name)

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import scipy
from scipy import ndimage
# probs should be a 64x48 torch tensor
def adaptive_thresh_cpu(probs, p_thresh=0.5, p_thresh_max=0.988):
    """Turn a 2-D LV probability map into a single-blob binary mask.

    The map is thresholded at ``max(probs) * t`` for t = p_thresh, p_thresh + 0.01, ...
    until at most one connected component remains or the next t would reach
    ``p_thresh_max``. The first threshold is always evaluated, even when
    ``p_thresh >= p_thresh_max``; that case used to crash internally and return zeros.
    If several components survive, the largest is kept (ties go to the lowest label).

    Args:
        probs: 2-D probability map (CPU torch tensor or array-like).
        p_thresh: first relative threshold.
        p_thresh_max: upper bound for further thresholds.

    Returns:
        torch.FloatTensor of shape ``(probs.shape[0], probs.shape[1])`` with 1 on the
        selected blob. It is all zeros when no pixel exceeds the threshold, or on error
        (the error is reported on stderr). Previously some paths returned a NumPy array
        or a uint8 tensor, which broke callers calling ``.detach()``.
    """
    p_thresh_incr = 0.01
    RO = probs.shape[0]
    E1 = probs.shape[1]
    try:
        probs_np = probs.detach().cpu().numpy() if torch.is_tensor(probs) else np.asarray(probs)
        peak = probs_np.max()
        step = 0
        while True:
            # Threshold from the step index rather than repeated += (no float drift).
            thresh = p_thresh + step * p_thresh_incr
            mask = probs_np > peak * thresh
            blobs, number_of_blobs = ndimage.label(mask)
            step += 1
            if number_of_blobs <= 1 or p_thresh + step * p_thresh_incr >= p_thresh_max:
                break

        if number_of_blobs == 1:
            return torch.from_numpy(mask.astype(np.float32))
        if number_of_blobs == 0:
            print("adaptive_thresh_cpu, did not find any blobs ... ", file=sys.stderr)
            sys.stderr.flush()
            return torch.zeros((RO, E1))

        ## If we are here then we cannot isolate a singular blob as the LV.
        ## Select the largest blob as the final mask. np.argmax returns the first
        ## maximum, i.e. the lowest label wins ties.
        areas = np.bincount(blobs.ravel())[1:]  # drop background label 0
        largest_label = int(np.argmax(areas)) + 1
        return torch.from_numpy((blobs == largest_label).astype(np.float32))
    except Exception as e:
        # Best effort: report and return an empty mask rather than abort an evaluation run.
        print("Error happened in adaptive_thresh_cpu ...", file=sys.stderr)
        print(e, file=sys.stderr)
        sys.stderr.flush()
        return torch.zeros((RO, E1))
def compute_dice_scores(best_model, loader, aif_trainer, binary_seg=False, device=None):
    """Compute the LV Dice score for every case served by ``loader``.

    Args:
        best_model: segmentation network; it is moved to ``device`` and set to eval mode.
        loader: DataLoader yielding ``(x, y, names)`` batches.
        aif_trainer: only its ``x_dtype`` and ``y_dtype`` attributes are used.
        binary_seg: True for a single-channel network (sigmoid, channel 0). Otherwise
            channel 1 of a softmax over classes is taken as the LV probability.
        device: torch device to run on. Defaults to CUDA when available, else CPU
            (previously CUDA was hard-coded, which failed on CPU-only machines).

    Returns:
        (dice_scores, cases): parallel lists of per-case Dice and case names.
        Cases with Dice < 0.1 are printed and their probabilities/masks plotted.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dice_scores = []
    cases = []
    best_model = best_model.to(device)
    best_model.eval()  # set model to evaluation mode
    lv_channel = 0 if binary_seg else 1
    for x, y, names in loader:
        x = x.to(aif_trainer.x_dtype).to(device)
        y = y.to(aif_trainer.y_dtype)
        with torch.no_grad():
            scores = best_model(x)
            if binary_seg:
                probs = torch.sigmoid(scores)  # F.sigmoid is deprecated
            else:
                probs = torch.softmax(scores, dim=1)
        probs = probs.cpu()
        aif_mask = y.cpu().numpy()
        for n in range(x.shape[0]):
            lv_probs = probs[n, lv_channel, :, :]
            lv_mask = adaptive_thresh_cpu(lv_probs, p_thresh=0.5)
            # Accept both tensor and ndarray results from the thresholding helper.
            lv_mask = lv_mask.detach().numpy() if torch.is_tensor(lv_mask) else np.asarray(lv_mask)
            # Reference LV = pixels labelled 1 in the ground-truth mask.
            lv_aif_mask = np.zeros(lv_mask.shape)
            lv_aif_mask[np.where(np.squeeze(aif_mask[n, 0, :, :] == 1))] = 1
            ds = training.dice(lv_aif_mask, lv_mask)
            if ds < 0.1:
                print(names[n])
                # (classes, RO, E1) -> (RO, E1, classes): same orientation as the masks
                # below (the old (2, 1, 0) transpose swapped RO and E1).
                curr_probs = np.transpose(probs[n, :, :, :].numpy(), (1, 2, 0))
                a = utils.cmr_ml_utils_plotting.plot_image_array(np.squeeze(curr_probs), columns=8, figsize=[16, 16])
                a = utils.cmr_ml_utils_plotting.plot_image_array(lv_aif_mask, columns=8, figsize=[16, 16])
                a = utils.cmr_ml_utils_plotting.plot_image_array(lv_mask, columns=8, figsize=[16, 16])
            dice_scores.append(ds)
            cases.append(names[n])
    return dice_scores, cases
def get_failed_cases(dice_scores, cases, thres=0.5, print_failed=True):
    """Split cases into successes (Dice >= ``thres``) and failures.

    Args:
        dice_scores: per-case Dice scores.
        cases: case names, parallel to ``dice_scores``.
        thres: minimum Dice counted as a success.
        print_failed: print each failed case with its Dice.

    Returns:
        (failed_cases, failed_dices, success_rate). ``success_rate`` is the fraction of
        cases with Dice >= thres, or 0.0 when there are no cases (this used to raise
        ZeroDivisionError).
    """
    total_samples = len(dice_scores)
    success_samples = 0
    failed_cases = []
    failed_dices = []
    for k, dice in enumerate(dice_scores):
        if dice >= thres:
            success_samples += 1
            continue
        if print_failed:
            print("case %s, dice %f " % (cases[k], dice))
        failed_cases.append(cases[k])
        failed_dices.append(dice)
    success_rate = success_samples / total_samples if total_samples else 0.0
    print("Total test samples is ", total_samples)
    print("Success rate is ", success_rate)
    return failed_cases, failed_dices, success_rate
def load_apply_model_multi_class(img_dir, case_name, model, device, min_reps=64, W=48):
    """Run a multi-class AIF model on one case folder and plot the result.

    Loads ``aif.npy`` (indexed as (RO, E1, N)) and the mask (``aif_masks_final.npy``,
    falling back to ``aif_masks.npy``) from ``img_dir/case_name``. The series is brought
    to ``min_reps`` frames (cropped, or padded by repeating the last frame) and ``W``
    columns (centre-cropped, or zero-padded), scaled to a maximum of 1, and passed
    through ``model``. Previously E1 < W gave a negative crop offset, N < min_reps went
    unpadded, and integer inputs made the in-place division fail.

    Returns:
        (probs, lv_mask, aif_mask): per-class probabilities squeezed to (RO, W, classes),
        the LV mask from ``training.adaptive_thresh`` on class 1, and the loaded mask.
    """
    data_dir = os.path.join(img_dir, case_name)
    model_device = device
    print(data_dir)
    Gd = np.load(os.path.join(data_dir, 'aif.npy'))
    RO, E1, N = Gd.shape
    try:
        aif_mask = np.load(os.path.join(data_dir, 'aif_masks_final.npy'))
    except (OSError, ValueError):
        aif_mask = np.load(os.path.join(data_dir, 'aif_masks.npy'))

    if not np.issubdtype(Gd.dtype, np.floating):
        Gd = Gd.astype(np.float64)

    # Time axis -> exactly min_reps frames.
    if N < min_reps:
        Gd = np.concatenate([Gd, np.repeat(Gd[:, :, N - 1:N], min_reps - N, axis=2)], axis=2)
    Gd = Gd[:, :, 0:min_reps]

    # E1 axis -> exactly W columns.
    if E1 >= W:
        s = (E1 - W) // 2
        Gd = Gd[:, s:s + W, :]
    else:
        s = (W - E1) // 2
        padded = np.zeros((RO, W, Gd.shape[2]), dtype=Gd.dtype)
        padded[:, s:s + E1, :] = Gd
        Gd = padded

    Gd = np.transpose(Gd, (2, 0, 1))
    Gd = np.reshape(Gd, (1, Gd.shape[0], Gd.shape[1], Gd.shape[2]))
    peak = np.max(Gd)
    if peak != 0:  # avoid NaNs for an all-zero series
        Gd = Gd / peak

    aif = torch.from_numpy(Gd).float()
    aif = aif.to(device=model_device)
    model.eval()
    with torch.no_grad():
        scores = model(aif)
        probs = torch.softmax(scores, dim=1)

    probs = probs.cpu().detach()
    lv_probs = probs[0, 1, :, :]
    lv_mask = training.adaptive_thresh(lv_probs, device=torch.device('cpu'), p_thresh=0.5)
    lv_mask = lv_mask.cpu().detach().numpy()
    probs = probs.numpy()
    probs = np.squeeze(np.transpose(probs, (2, 3, 1, 0)))
    a = utils.cmr_ml_utils_plotting.plot_image_array(probs, columns=8, figsize=[16, 16])
    a = utils.cmr_ml_utils_plotting.plot_image_array(aif_mask, columns=8, figsize=[16, 16])
    a = utils.cmr_ml_utils_plotting.plot_image_array(lv_mask, columns=8, figsize=[16, 16])
    return probs, lv_mask, aif_mask
# data augmenation for random flipping
# Random flips along RO and E1 (each with probability 0.5) for the 'lv_rv' task.
transform = torchvision.transforms.Compose([
    RandomFlip1stDim(0.5),
    RandomFlip2ndDim(0.5),
])
perf_aif_dataset.transform = transform
perf_aif_dataset.which_mask = 'lv_rv'

# Three output classes; accuracy is tracked on classes 1 and 2, and class 1
# (the LV label in the 'lv_rv' masks) is weighted 5x in class_weights.
num_classes = 3
class_for_accu = [1, 2]
class_weights = np.ones(num_classes)
class_weights[1] = 5
p_thres = [0.5, 0.5, 0.75]

# Inspect one (augmented) sample's mask.
print(perf_aif_dataset.which_mask)
sample = perf_aif_dataset[1]
print(sample[1].shape)
#Original Plot from research paper
plt.figure()
plt.imshow(np.squeeze(sample[1]))
#Plotly
# Heat-map of the sample's mask with dropdowns to change / reverse the colour
# scale, written to figure_5.html and shown inline.
fig = go.Figure()
fig.add_trace(go.Heatmap(z=np.squeeze(sample[1]), colorscale="Gray"))

# Figure size and margins.
fig.update_layout(width=800, height=900, autosize=False,
                  margin=dict(t=100, b=0, l=0, r=0))
fig.update_scenes(aspectratio=dict(x=1, y=1, z=0.7), aspectmode="manual")

button_layer_1_height = 1.08
fig.update_layout(updatemenus=[
    dict(
        buttons=[dict(args=["colorscale", scale], label=scale, method="restyle")
                 for scale in ["Gray", "Cividis", "Blues", "Viridis"]],
        direction="down", pad={"r": 10, "t": 10}, showactive=True,
        x=0.1, xanchor="left", y=button_layer_1_height, yanchor="top",
    ),
    dict(
        buttons=[dict(args=["reversescale", flag], label=str(flag), method="restyle")
                 for flag in [False, True]],
        direction="down", pad={"r": 10, "t": 10}, showactive=True,
        x=0.37, xanchor="left", y=button_layer_1_height, yanchor="top",
    ),
])
fig.update_layout(annotations=[
    dict(text="colorscale", x=0, xref="paper", y=1.06, yref="paper",
         align="left", showarrow=False),
    dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
         yref="paper", showarrow=False),
])
plotly.offline.plot(fig, filename='figure_5.html')
display(HTML('figure_5.html'))
# Peek at one (augmented) training batch: time frame 32 of every sample, and the masks.
# Builtin next(): recent PyTorch DataLoader iterators have no `.next()` method.
images, masks, names = next(iter_train)
B, C, RO, E1 = images.shape
print(images.shape)
print(masks.shape)
print(torch.max(images))
print(torch.max(masks))
a = images[:, 32, :, :]
a = torch.reshape(a, (B, 1, RO, E1))
plt.figure(figsize=(8, 8))
show(make_grid(a.double(), nrow=8, padding=2, normalize=False, scale_each=True))
plt.figure(figsize=(8, 8))
show(make_grid(masks.double(), nrow=8, padding=2, normalize=True, scale_each=False))
X = images.type(torch.FloatTensor)
y = masks.type(torch.FloatTensor)
print(X.shape)
print(y.shape)
def perform_training(hyperpara, perf_aif_dataset, loader_for_train, loader_for_val):
    """Train one multi-class ('lv_rv' masks, 3 classes) ResUnet configuration and evaluate it.

    Args:
        hyperpara (dict): must provide 'num_epochs', 'inplanes', 'layers',
            'layers_planes', 'class_weights' (the loss weight of class 1) and
            'jaccard_weight'. It is updated in place with the results and returned.
        perf_aif_dataset: the dataset; its ``which_mask`` is set to 'lv_rv'.
        loader_for_train, loader_for_val: training / validation DataLoaders.

    Returns:
        dict: ``hyperpara`` with the best model, training curves, per-case Dice
        scores on validation and training data, and failure statistics.

    Side effects:
        TensorBoard logs under ``aif_training/``, whatever the trainer saves to
        ``perf_training/``, and ``perf_aif_lv_rv_val_failed.mat`` in the global ``img_dir``.
    """
    num_epochs = hyperpara['num_epochs']
    print_every = 100000
    inplanes = hyperpara['inplanes']
    layers = hyperpara['layers']
    layers_planes = hyperpara['layers_planes']
    class_weights = hyperpara['class_weights']
    jaccard_weight = hyperpara['jaccard_weight']

    print('======================================================')
    print('num_epochs ', num_epochs)
    print('inplanes ', inplanes)
    print('layers ', layers)
    print('layers_planes ', layers_planes)
    print('class_weights ', class_weights)
    print('jaccard_weight ', jaccard_weight)
    print('======================================================')

    perf_aif_dataset.which_mask = 'lv_rv'
    num_classes = 3
    class_for_accu = [1, 2]
    p_thres = [0.5, 0.5, 0.75]

    print(perf_aif_dataset.aif[0].shape)
    C, H, W = perf_aif_dataset.aif[0].shape
    model = models.GadgetronResUnet18(F0=C,
                                      inplanes=inplanes,
                                      layers=layers,
                                      layers_planes=layers_planes,
                                      use_dropout=False,
                                      p=0.5,
                                      H=H, W=W, C=num_classes,
                                      verbose=True)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print("model on multiple GPU ... ")

    patience = 10
    factor = 0.5
    cooldown = 3
    min_lr = 1e-7
    weight_decay = 0
    learning_rate = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)
    # `factor` used to be defined but never passed, so the scheduler silently used
    # ReduceLROnPlateau's default decay factor (0.1) instead of the intended 0.5.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)

    # Loss weights: 1 for every class except class 1.
    CW = np.ones(num_classes)
    CW[1] = class_weights
    criterion = training.LossMulti(class_weights=CW, jaccard_weight=jaccard_weight)

    log_dir = 'aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
    writer = SummaryWriter(log_dir)
    aif_trainer = training.GadgetronMultiClassSeg_Perf(model,
                                                       optimizer,
                                                       criterion,
                                                       loader_for_train,
                                                       loader_for_val,
                                                       class_for_accu=class_for_accu,
                                                       p_thres=p_thres,
                                                       scheduler=scheduler,
                                                       epochs=num_epochs,
                                                       device=device,
                                                       x_dtype=torch.float32,
                                                       y_dtype=torch.long,
                                                       early_stopping_thres=100,
                                                       print_every=print_every,
                                                       writer=writer,
                                                       model_folder="perf_training/")
    epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)
    writer.close()  # flush the TensorBoard event file for this configuration

    dice_scores, cases = compute_dice_scores(best_model, loader_for_val, aif_trainer)
    failed_cases, failed_dices, success_rate = get_failed_cases(dice_scores, cases, thres=0.5, print_failed=False)
    # The dict literal used the key "cases" twice, so the failed-case list was overwritten
    # by the full case list. "cases" keeps the value it effectively had; failed cases
    # now get their own key.
    scipy.io.savemat(os.path.join(img_dir, 'perf_aif_lv_rv_val_failed.mat'),
                     {"failed_cases": failed_cases, "dices": failed_dices,
                      "dice_scores": dice_scores, "cases": cases})

    dice_scores_train, cases_train = compute_dice_scores(best_model, loader_for_train, aif_trainer)
    failed_cases_train, failed_dices_train, success_rate_train = get_failed_cases(dice_scores_train, cases_train, thres=0.5, print_failed=False)

    hyperpara['best_model'] = best_model
    hyperpara['epochs_traning'] = epochs_traning
    hyperpara['epochs_validation'] = epochs_validation
    hyperpara['loss_all'] = loss_all
    hyperpara['epochs_acc_class'] = epochs_acc_class
    hyperpara['dice_scores'] = dice_scores
    hyperpara['cases'] = cases
    hyperpara['failed_cases'] = failed_cases
    hyperpara['failed_dices'] = failed_dices
    hyperpara['success_rate'] = success_rate
    # These previously stored the *validation* scores/cases by mistake.
    hyperpara['dice_scores_train'] = dice_scores_train
    hyperpara['cases_train'] = cases_train
    hyperpara['success_rate_train'] = success_rate_train
    return hyperpara
# hyper parameter search
# Grid over decoder widths (layers_planes), block depths (layers) and stem width
# (inplanes). Each configuration is trained by perform_training and ranked by its
# validation success rate; the best run is kept in best_hyperpara.
import random

layers_planes = [[96, 128], [128, 128], [128, 256]]
layers = [[2, 3], [3, 3], [3, 4], [4, 4], [4, 5]]
inplanes = [64, 96, 128]
best_success_rate = 0
best_hyperpara = None
hyperpara_all = []

# Hold one of k random folds out for validation. `chunk` was never defined in this
# notebook (NameError), so the split is done inline. It is drawn once, before the
# grid, so every configuration is scored on the same validation cases; it used to be
# re-drawn per configuration, which made the success rates incomparable.
k = 12
batch_size = 256
shuffled_idxs = list(range(len(perf_aif_dataset)))
random.shuffle(shuffled_idxs)
listified_chunks = [shuffled_idxs[i::k] for i in range(k)]
val_idxs = listified_chunks[0]
train_idxs = [item for sublist in listified_chunks[1:] for item in sublist]
loader_for_train = DataLoader(perf_aif_dataset, batch_size=batch_size,
                              sampler=sampler.SubsetRandomSampler(train_idxs))
loader_for_val = DataLoader(perf_aif_dataset, batch_size=batch_size,
                            sampler=sampler.SubsetRandomSampler(val_idxs))
num_train = len(train_idxs)
print('num_train = %d' % num_train)
num_val = len(val_idxs)
print('num_val = %d' % num_val)

for a in range(len(layers_planes)):
    for b in range(len(layers)):
        for c in range(len(inplanes)):
            print('-----------------------------------------------')
            print(a, b, c)
            hyperpara = dict()
            hyperpara['num_epochs'] = 40
            hyperpara['inplanes'] = inplanes[c]
            hyperpara['layers'] = layers[b]
            hyperpara['layers_planes'] = layers_planes[a]
            hyperpara['class_weights'] = 5.0
            hyperpara['jaccard_weight'] = 0.5
            hyperpara = perform_training(hyperpara, perf_aif_dataset, loader_for_train, loader_for_val)
            print('success rate - ', hyperpara['success_rate'])
            if hyperpara['success_rate'] > best_success_rate:
                best_success_rate = hyperpara['success_rate']
                best_hyperpara = hyperpara
            hyperpara_all.append(hyperpara)
num_epochs = 50
print_every = 100000

# Architecture presets; only the last one is active (the earlier ones were dead
# assignments that were immediately overwritten). Kept here for reference:
#   resnet:    inplanes = 96; layers = [4, 4]; layers_planes = [128, 128]; growth_rate = 8
#   dense net: inplanes = 64; layers = [3, 3]; layers_planes = [16, 32];   growth_rate = 16
# resnet, small
inplanes = 96
layers = [2, 3]
layers_planes = [128, 128]
growth_rate = 8

print(perf_aif_dataset.aif[0].shape)
C, H, W = perf_aif_dataset.aif[0].shape
model = models.GadgetronResUnet18(F0=C,
                                  inplanes=inplanes,
                                  layers=layers,
                                  layers_planes=layers_planes,
                                  use_dropout=False,
                                  p=0.5,
                                  H=H, W=W, C=num_classes,
                                  verbose=True)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print("model on multiple GPU ... ")

patience = 10
factor = 0.5
cooldown = 3
min_lr = 1e-7
weight_decay = 0
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)
# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)
# `factor` was defined but never passed, so the scheduler used its default of 0.1.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)
criterion = training.LossMulti(class_weights=class_weights, jaccard_weight=0.5)
# criterion = nn.BCEWithLogitsLoss()
# criterion = nn.BCELoss()
log_dir = './aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
writer = SummaryWriter(log_dir)
aif_trainer = training.GadgetronMultiClassSeg_Perf(model,
                                                   optimizer,
                                                   criterion,
                                                   loader_for_train,
                                                   loader_for_val,
                                                   class_for_accu=class_for_accu,
                                                   p_thres=p_thres,
                                                   scheduler=scheduler,
                                                   epochs=num_epochs,
                                                   device=device,
                                                   x_dtype=torch.float32,
                                                   y_dtype=torch.long,
                                                   early_stopping_thres=100,
                                                   print_every=print_every,
                                                   small_data_mode=False,
                                                   writer=writer,
                                                   model_folder="aif_training/")
epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)
#Original Plot from research paper
# Column 1 of loss_all against column 0, for at most the first 500 rows.
fig = plt.figure()
plt.plot(loss_all[:500, 0], loss_all[:500, 1])

#Plotly
fig = go.Figure(data=go.Scatter(x=loss_all[:500, 0], y=loss_all[:500, 1], name='Loss'))
plotly.offline.plot(fig, filename='figure_6.html')
display(HTML('figure_6.html'))

# Accuracy / loss of the best model on the validation loader.
acc, loss, acc_class = aif_trainer.check_validation_test_accuracy(loader_for_val, best_model)
print(acc, loss)
print(acc_class)
# data augmenation for random flipping
# Binary task: LV-only masks, a single output channel and a 0.5 probability
# threshold; the same random RO/E1 flips are used for augmentation.
transform = torchvision.transforms.Compose([
    RandomFlip1stDim(0.5),
    RandomFlip2ndDim(0.5),
])
perf_aif_dataset.transform = transform
perf_aif_dataset.which_mask = 'lv'
num_classes = 1
p_thres = 0.5

# Inspect one (augmented) sample's mask.
print(perf_aif_dataset.which_mask)
sample = perf_aif_dataset[1]
print(sample[1].shape)
#Original Plot from research paper
plt.figure()
plt.imshow(np.squeeze(sample[1]))
#Plotly
# Heat-map of the sample's LV mask with dropdowns to change / reverse the colour
# scale, written to figure_7.html and shown inline.
fig = go.Figure()
fig.add_trace(go.Heatmap(z=np.squeeze(sample[1]), colorscale="Gray"))

# Figure size and margins.
fig.update_layout(width=800, height=900, autosize=False,
                  margin=dict(t=100, b=0, l=0, r=0))
fig.update_scenes(aspectratio=dict(x=1, y=1, z=0.7), aspectmode="manual")

button_layer_1_height = 1.08
fig.update_layout(updatemenus=[
    dict(
        buttons=[dict(args=["colorscale", scale], label=scale, method="restyle")
                 for scale in ["Gray", "Cividis", "Blues", "Viridis"]],
        direction="down", pad={"r": 10, "t": 10}, showactive=True,
        x=0.1, xanchor="left", y=button_layer_1_height, yanchor="top",
    ),
    dict(
        buttons=[dict(args=["reversescale", flag], label=str(flag), method="restyle")
                 for flag in [False, True]],
        direction="down", pad={"r": 10, "t": 10}, showactive=True,
        x=0.37, xanchor="left", y=button_layer_1_height, yanchor="top",
    ),
])
fig.update_layout(annotations=[
    dict(text="colorscale", x=0, xref="paper", y=1.06, yref="paper",
         align="left", showarrow=False),
    dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
         yref="paper", showarrow=False),
])
plotly.offline.plot(fig, filename='figure_7.html')
display(HTML('figure_7.html'))
# Peek at one (augmented) training batch: time frame 32 of every sample, and the masks.
# Builtin next(): recent PyTorch DataLoader iterators have no `.next()` method.
images, masks, names = next(iter_train)
B, C, RO, E1 = images.shape
print(images.shape)
print(masks.shape)
print(torch.max(images))
print(torch.max(masks))
a = images[:, 32, :, :]
a = torch.reshape(a, (B, 1, RO, E1))
plt.figure(figsize=(16, 16))
show(make_grid(a.double(), nrow=8, padding=2, normalize=False, scale_each=True))
plt.figure(figsize=(16, 16))
show(make_grid(masks.double(), nrow=8, padding=2, normalize=True, scale_each=False))
X = images.type(torch.FloatTensor)
y = masks.type(torch.FloatTensor)
print(X.shape)
print(y.shape)
# Binary (LV-only) ResUnet training.
num_epochs = 50
print_every = 100000
inplanes = 96
layers = [4, 4]
layers_planes = [128, 128]

print(perf_aif_dataset.aif[0].shape)
C, H, W = perf_aif_dataset.aif[0].shape
model = models.GadgetronResUnet18(F0=C,
                                  inplanes=inplanes,
                                  layers=layers,
                                  layers_planes=layers_planes,
                                  use_dropout=False,
                                  p=0.5,
                                  H=H, W=W, C=num_classes,
                                  verbose=True)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print("model on multiple GPU ... ")

patience = 10
factor = 0.5
cooldown = 3
min_lr = 1e-7
weight_decay = 0
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)
# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)
# `factor` was defined but never passed, so the scheduler used its default of 0.1.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)
criterion = training.LossBinary(jaccard_weight=0.5)
# criterion = training.LossMulti(class_weights=class_weights, jaccard_weight=0.5)
log_dir = 'aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
writer = SummaryWriter(log_dir)
aif_trainer = training.GadgetronTwoClassSeg_PerfAIF(model,
                                                    optimizer,
                                                    criterion,
                                                    loader_for_train,
                                                    loader_for_val,
                                                    p_thres=p_thres,
                                                    scheduler=scheduler,
                                                    epochs=num_epochs,
                                                    device=device,
                                                    x_dtype=torch.float32,
                                                    y_dtype=torch.float32,
                                                    early_stopping_thres=100,
                                                    print_every=print_every,
                                                    small_data_mode=False,
                                                    writer=writer,
                                                    model_folder="aif_training/")
epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)
# Unwrap (Distributed)DataParallel if present and move the best model to the CPU for export.
# An explicit isinstance check replaces the bare try/except around `.module`.
best_model_cpu = best_model.cpu()
if isinstance(best_model_cpu, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
    best_model_cpu = best_model_cpu.module
print(best_model_cpu)

from datetime import date
today = str(date.today())
print(today)
from time import gmtime, strftime
# UTC timestamp that makes the exported file name unique.
moment = strftime("%Y%m%d_%H%M%S", gmtime())
print(moment)

model_dir = './deployment/networks'
os.makedirs(model_dir, exist_ok=True)  # torch.save does not create missing folders
model_file = model_dir + '/perf_aif_' + perf_aif_dataset.which_mask + '_network_' + moment + '.pbt'
print(model_file)
torch.save(best_model_cpu, model_file)

# The file holds a whole pickled nn.Module, not a state_dict. PyTorch >= 2.6 defaults
# torch.load to weights_only=True, which rejects such files, so opt out explicitly.
# Older PyTorch without the keyword raises TypeError; fall back to the plain call.
# Unpickling runs arbitrary code: only load files you produced yourself.
try:
    model_loaded = torch.load(model_file, weights_only=False)
except TypeError:
    model_loaded = torch.load(model_file)
print(model_loaded)