Commit 4486ce8d authored by Alessia Marcolini

Handle modality in a single network

parent 7c0eac2a
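
This commit merges the separate CT-only, PET-only and dual-branch networks into a single CiompiDO class that takes a modality argument ('CT', 'PET' or 'CT/PET') and slices the matching channel inside forward(). A minimal usage sketch of the refactored class (the import path is hypothetical, and the feature sizes are inferred from the layer definitions in this diff):

import torch
from networks import CiompiDO  # hypothetical module path; adjust to this repo's layout

# dual-modality model: CT and PET stacked as two channels of the same tensor
model = CiompiDO(n_classes=2, n_channels=2, modality='CT/PET', dropout=0.5)
model.initialize_weights()

# single-modality variants still receive the full 2-channel tensor;
# forward() picks out channel 0 (CT) or channel 1 (PET) according to `modality`
model_ct = CiompiDO(n_classes=2, n_channels=1, modality='CT')
model_pt = CiompiDO(n_classes=2, n_channels=1, modality='PET')

x = torch.randn(4, 2, 64, 64, 64)     # N_batch x N_ch x s x s x s
probs = model(x)                      # (4, 2) softmax scores
feats = model.extract_features(x)     # (4, 512): concatenated 256-dim CT and PET features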
......@@ -114,127 +114,19 @@ class Ciompi(nn.Module):
nn.init.constant_(m.bias, 0)
class CiompiDO(nn.Module):
def __init__(self, n_classes=2, n_channels=2, dropout=0.5):
def __init__(self, n_classes=2, n_channels=2, modality='CT/PET', dropout=0.5):
assert modality in ['CT', 'PET', 'CT/PET']
if modality == 'CT/PET':
assert n_channels == 2
else:
assert n_channels == 1
super(CiompiDO, self).__init__()
self.n_classes = n_classes
self.n_channels = n_channels
self.dropout = dropout
self.CT_branch = nn.Sequential( #64x64x64
nn.BatchNorm3d(1),
nn.Conv3d(1, 32, 5), #64-5+1 = 60x60x60
nn.BatchNorm3d(32),
nn.ReLU(),
nn.MaxPool3d(2), #30x30x30
nn.Dropout3d(self.dropout),
nn.Conv3d(32, 64, 3), #30-3+1 = 28x28x28
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Conv3d(64, 64, 3), #28-3+1 = 26x26x26
nn.BatchNorm3d(64),
nn.ReLU(),
nn.MaxPool3d(2), #13x13x13
nn.Dropout3d(self.dropout/2),
nn.Conv3d(64, 128, 3), #13-3+1 = 11x11x11
nn.BatchNorm3d(128),
nn.ReLU(),
nn.MaxPool3d(2), #5x5x5
nn.Dropout3d(self.dropout/2),
nn.Conv3d(128, 256, 3), #5-3+1 = 3x3x3
nn.BatchNorm3d(256),
nn.ReLU(),
nn.MaxPool3d(2), #1x1x1
nn.Dropout3d(self.dropout/4),
nn.AdaptiveAvgPool3d(1) #<- unnecessary for 64-sized inputs, but if we move to 128 it is better to keep it, otherwise the first linear layer must be changed
)
self.PT_branch = nn.Sequential( #64x64x64
nn.BatchNorm3d(1),
nn.Conv3d(1, 32, 5), #64-5+1 = 60x60x60
nn.BatchNorm3d(32),
nn.ReLU(),
nn.MaxPool3d(2), #30x30x30
nn.Dropout3d(self.dropout),
nn.Conv3d(32, 64, 3), #30-3+1 = 28x28x28
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Conv3d(64, 64, 3), #28-3+1 = 26x26x26
nn.BatchNorm3d(64),
nn.ReLU(),
nn.MaxPool3d(2), #13x13x13
nn.Dropout3d(self.dropout/2),
nn.Conv3d(64, 128, 3), #13-3+1 = 11x11x11
nn.BatchNorm3d(128),
nn.ReLU(),
nn.MaxPool3d(2), #5x5x5
nn.Dropout3d(self.dropout/2),
nn.Conv3d(128, 256, 3), #5-3+1 = 3x3x3
nn.BatchNorm3d(256),
nn.ReLU(),
nn.MaxPool3d(2), #1x1x1
nn.Dropout3d(self.dropout/4),
nn.AdaptiveAvgPool3d(1) #<- unnecessary for 64-sized inputs, but if we move to 128 it is better to keep it, otherwise the first linear layer must be changed
)
self.linear = nn.Sequential(
nn.Dropout(self.dropout),
nn.Linear(256*2, 50),
nn.ReLU(),
nn.Dropout(self.dropout/2),
nn.Linear(50, self.n_classes),
nn.Softmax(1)
)
def forward(self, x): # x = N_batch x N_ch x s x s x s
x_CT = x[:, 0, :,:,:].unsqueeze(1) # x_CT = N_batch x 1 x s x s x s
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1) #n_batch * X
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out_merged = torch.cat([features_CT, features_PT], dim=1)
out = self.linear(out_merged)
return(out)
def extract_features(self, x):
x_CT = x[:, 0, :,:,:].unsqueeze(1)
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1)
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out_merged = torch.cat([features_CT, features_PT], dim=1)
return(out_merged)
self.modality = modality
def initialize_weights(self):
for m in self.modules():
#if isinstance(m, nn.Linear):
# nn.init.xavier_uniform_(m.weight)
# nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class CiompiDO_CT(nn.Module):
def __init__(self, n_classes=2, n_channels=1, dropout=0.5):
super(CiompiDO_CT, self).__init__()
self.n_classes = n_classes
self.n_channels = n_channels
self.dropout = dropout
self.CT_branch = nn.Sequential( #64x64x64
nn.BatchNorm3d(1),
nn.Conv3d(1, 32, 5), #64-5+1 = 60x60x60
......@@ -266,49 +158,7 @@ class CiompiDO_CT(nn.Module):
nn.AdaptiveAvgPool3d(1) #<- unnecessary for 64-sized inputs, but if we move to 128 it is better to keep it, otherwise the first linear layer must be changed
)
self.linear = nn.Sequential(
nn.Dropout(self.dropout),
nn.Linear(256, 50),
nn.ReLU(),
nn.Dropout(self.dropout/2),
nn.Linear(50, self.n_classes),
nn.Softmax(1)
)
def forward(self, x): # x = N_batch x N_ch x s x s x s
x_CT = x[:, 0, :,:,:].unsqueeze(1) # x_CT = N_batch x 1 x s x s x s
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1) #n_batch * X
out = self.linear(features_CT)
return(out)
def extract_features(self, x):
x_CT = x[:, 0, :,:,:].unsqueeze(1)
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1)
return(features_CT)
def initialize_weights(self):
for m in self.modules():
#if isinstance(m, nn.Linear):
# nn.init.xavier_uniform_(m.weight)
# nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class CiompiDO_PT(nn.Module):
def __init__(self, n_classes=2, n_channels=1, dropout=0.5):
super(CiompiDO_PT, self).__init__()
self.n_classes = n_classes
self.n_channels = n_channels
self.dropout = dropout
self.PT_branch = nn.Sequential( #64x64x64
nn.BatchNorm3d(1),
nn.Conv3d(1, 32, 5), #64-5+1 = 60x60x60
......@@ -343,7 +193,7 @@ class CiompiDO_PT(nn.Module):
self.linear = nn.Sequential(
nn.Dropout(self.dropout),
nn.Linear(256, 50),
nn.Linear(256*self.n_channels, 50),
nn.ReLU(),
nn.Dropout(self.dropout/2),
......@@ -351,219 +201,52 @@ class CiompiDO_PT(nn.Module):
nn.Softmax(1)
)
def forward(self, x): # x = N_batch x N_ch x s x s x s
x_PT = x[:, 1, :,:,:].unsqueeze(1) # x_PT = N_batch x 1 x s x s x s
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1) #n_batch * X
out = self.linear(features_PT)
return(out)
def extract_features(self, x):
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
return(features_PT)
def initialize_weights(self):
for m in self.modules():
#if isinstance(m, nn.Linear):
# nn.init.xavier_uniform_(m.weight)
# nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class CiompiDOSimple(nn.Module):
def __init__(self, n_classes=2, n_channels=2, dropout=0.5):
super(CiompiDOSimple, self).__init__()
self.n_classes = n_classes
self.n_channels = n_channels
self.dropout = dropout
self.CT_branch = nn.Sequential( #64x64x64
nn.BatchNorm3d(1),
nn.Conv3d(1, 10, 5), #64-5+1 = 60x60x60
nn.BatchNorm3d(10),
nn.ReLU(),
nn.Dropout3d(self.dropout),
nn.MaxPool3d(2), #30x30x30
nn.Conv3d(10, 20, 3), #30-3+1 = 28x28x28
nn.BatchNorm3d(20),
nn.ReLU(),
nn.Dropout3d(self.dropout/2),
nn.MaxPool3d(2), #14x14x14
nn.Conv3d(20, 40, 3), #14-3+1 = 12x12x12
nn.BatchNorm3d(40),
nn.ReLU(),
nn.Dropout3d(self.dropout/2),
nn.MaxPool3d(2), #6x6x6
#nn.Conv3d(40, 40, 3), #5-3+1 = 2x2x2
#nn.ReLU(),
#nn.Dropout(self.dropout/4),
#nn.MaxPool3d(2), #1x1x1
nn.AdaptiveAvgPool3d(1) #<- unnecessary for 64-sized inputs, but if we move to 128 it is better to keep it, otherwise the first linear layer must be changed
)
self.PT_branch = nn.Sequential( #64x64x64
nn.BatchNorm3d(1),
nn.Conv3d(1, 10, 5), #64-5+1 = 60x60x60
nn.BatchNorm3d(10),
nn.ReLU(),
nn.Dropout3d(self.dropout),
nn.MaxPool3d(2), #30x30x30
nn.Conv3d(10, 20, 3), #30-3+1 = 28x28x28
nn.BatchNorm3d(20),
nn.ReLU(),
nn.Dropout3d(self.dropout/2),
nn.MaxPool3d(2), #14x14x14
nn.Conv3d(20, 40, 3), #14-3+1 = 12x12x12
nn.BatchNorm3d(40),
nn.ReLU(),
nn.Dropout3d(self.dropout/2),
nn.MaxPool3d(2), #6x6x6
def forward(self, x): # x = N_batch x N_ch x s x s x s
if self.modality == 'CT':
x_CT = x[:, 0, :,:,:].unsqueeze(1) # only CT volumes
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1) #n_batch * X
out = self.linear(features_CT)
#nn.Conv3d(40, 40, 3), #5-3+1 = 2x2x2
#nn.BatchNorm3d(40),
#nn.ReLU(),
#nn.Dropout(self.dropout/4),
#nn.MaxPool3d(2), #1x1x1
elif self.modality == 'PET':
x_PT = x[:, 1, :,:,:].unsqueeze(1) # only PET volumes
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out = self.linear(features_PT)
nn.AdaptiveAvgPool3d(1) #<- unnecessary for 64-sized inputs, but if we move to 128 it is better to keep it, otherwise the first linear layer must be changed
)
self.linear = nn.Sequential(
nn.Linear(40*2, 50),
nn.ReLU(),
nn.Dropout(self.dropout),
nn.Linear(50, self.n_classes),
nn.Softmax(1)
)
def forward(self, x): # x = N_batch x N_ch x s x s x s
x_CT = x[:, 0, :,:,:].unsqueeze(1) # x_CT = N_batch x 1 x s x s x s
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1) #n_batch * X
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out_merged = torch.cat([features_CT, features_PT], dim=1)
out = self.linear(out_merged)
return(out)
else:
x_CT = x[:, 0, :,:,:].unsqueeze(1) # x_CT = N_batch x 1 x s x s x s
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1) #n_batch * X
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out_merged = torch.cat([features_CT, features_PT], dim=1)
out = self.linear(out_merged)
return out
def extract_features(self, x):
x_CT = x[:, 0, :,:,:].unsqueeze(1)
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1)
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out_merged = torch.cat([features_CT, features_PT], dim=1)
return(out_merged)
def initialize_weights(self):
for m in self.modules():
#if isinstance(m, nn.Linear):
# nn.init.xavier_uniform_(m.weight)
# nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class CiompiSimple(nn.Module):
def __init__(self):
super(CiompiSimple, self).__init__()
self.n_classes = 2
self.n_channels = 2
#dropout = 0.5
self.CT_branch = nn.Sequential( #64x64x64
nn.BatchNorm3d(1),
nn.Conv3d(1, 32, 5), #64-5+1 = 60x60x60
nn.BatchNorm3d(32),
nn.ReLU(),
nn.MaxPool3d(2), #30x30x30
nn.Conv3d(32, 64, 3), #30-3+1 = 28x28x28
nn.BatchNorm3d(64),
nn.Conv3d(64, 64, 3), #28-3+1 = 26x26x26
nn.BatchNorm3d(64),
nn.ReLU(),
nn.MaxPool3d(2), #13x13x13
nn.Conv3d(64, 128, 3), #13-3+1 = 11x11x11
nn.BatchNorm3d(128),
nn.ReLU(),
nn.MaxPool3d(2), #5x5x5
nn.AdaptiveAvgPool3d(1)
)
self.PT_branch = nn.Sequential(
nn.BatchNorm3d(1),
nn.Conv3d(1, 32, 5), #64-5+1 = 60x60x60
nn.BatchNorm3d(32),
nn.ReLU(),
nn.MaxPool3d(2), #30x30x30
nn.Conv3d(32, 64, 3), #30-3+1 = 28x28x28
nn.BatchNorm3d(64),
nn.Conv3d(64, 64, 3), #28-3+1 = 26x26x26
nn.BatchNorm3d(64),
nn.ReLU(),
nn.MaxPool3d(2), #13x13x13
if self.modality == 'CT':
x_CT = x[:, 0, :,:,:].unsqueeze(1) # only CT volumes
features_CT = self.CT_branch(x_CT)
out = features_CT.view(x.shape[0], -1) #n_batch * X
nn.Conv3d(64, 128, 3), #13-3+1 = 11x11x11
nn.BatchNorm3d(128),
nn.ReLU(),
nn.MaxPool3d(2), #5x5x5
elif self.modality == 'PET':
x_PT = x[:, 1, :,:,:].unsqueeze(1) # only PET volumes
features_PT = self.PT_branch(x_PT)
out = features_PT.view(x.shape[0], -1)
nn.AdaptiveAvgPool3d(1)
)
self.linear = nn.Sequential(
nn.Linear(128*2, 50),
nn.ReLU(),
nn.Linear(50, self.n_classes),
nn.Softmax(1)
)
def forward(self, x): # x = N_batch x N_ch x s x s x s
x_CT = x[:, 0, :,:,:].unsqueeze(1) # x_CT = N_batch x 1 x s x s x s
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1) #n_batch * X
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out_merged = torch.cat([features_CT, features_PT], dim=1)
out = self.linear(out_merged)
return(out)
def extract_features(self, x):
x_CT = x[:, 0, :,:,:].unsqueeze(1)
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1)
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out_merged = torch.cat([features_CT, features_PT], dim=1)
return(out_merged)
else:
x_CT = x[:, 0, :,:,:].unsqueeze(1) # x_CT = N_batch x 1 x s x s x s
features_CT = self.CT_branch(x_CT)
features_CT = features_CT.view(x.shape[0], -1) #n_batch * X
x_PT = x[:, 1, :,:,:].unsqueeze(1)
features_PT = self.PT_branch(x_PT)
features_PT = features_PT.view(x.shape[0], -1)
out = torch.cat([features_CT, features_PT], dim=1)
return out
def initialize_weights(self):
for m in self.modules():
......
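
For reference, the spatial sizes annotated in the branch comments above follow from unpadded convolutions (s -> s - k + 1) and 2x2x2 max-pooling (s -> floor(s/2)); the final AdaptiveAvgPool3d(1) is what keeps the flattened feature size equal to the channel count even if the patch size changes from 64 to 128. A small sketch (function name is just illustrative) reproducing the progression of the CiompiDO CT/PET branch:

# trace the spatial size through the branch, using the kernel sizes from the Sequential above
def trace_branch(s=64):
    sizes = [s]
    for op, k in [('conv', 5), ('pool', 2), ('conv', 3), ('conv', 3), ('pool', 2),
                  ('conv', 3), ('pool', 2), ('conv', 3), ('pool', 2)]:
        s = s - k + 1 if op == 'conv' else s // 2
        sizes.append(s)
    return sizes

print(trace_branch(64))    # [64, 60, 30, 28, 26, 13, 11, 5, 3, 1]
print(trace_branch(128))   # larger maps throughout, but AdaptiveAvgPool3d(1) still ends at 1x1x1 per channel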