Commit acaee3a5 authored by Alessia Marcolini

Fix typo + black formatting

parent 15da5952
@@ -6,29 +6,38 @@ from DP_classification.models.base_model import BaseModel

 class DAPPER_model(nn.Module):
-    def __init__(self, n_features_out, adapter_out_features=1000, version='', pretrained=False, n_classes=2, dropout=0.5):
+    def __init__(
+        self,
+        n_features_out,
+        adapter_out_features=1000,
+        version='',
+        pretrained=False,
+        n_classes=2,
+        dropout=0.5,
+    ):
         super(DAPPER_model, self).__init__()
-        self.adapter = nn.Sequential(nn.Linear(n_features_out, adapter_out_features),
-                                     nn.ReLU(True),
-                                     nn.Dropout(dropout))
-        self.fc_moba = nn.Sequential(nn.BatchNorm1d(1000),
-                                     nn.Linear(1000, 1000),
-                                     nn.ReLU(True),
-                                     nn.Dropout(dropout),
-                                     nn.BatchNorm1d(1000),
-                                     nn.Linear(1000, 256),
-                                     nn.ReLU(True),
-                                     nn.Dropout(dropout))
-        self.fc_final = nn.Sequential(nn.BatchNorm1d(256),
-                                      nn.Linear(256, n_classes))
+        self.adapter = nn.Sequential(
+            nn.Linear(n_features_out, adapter_out_features),
+            nn.ReLU(True),
+            nn.Dropout(dropout),
+        )
+        self.fc_moba = nn.Sequential(
+            nn.BatchNorm1d(1000),
+            nn.Linear(1000, 1000),
+            nn.ReLU(True),
+            nn.Dropout(dropout),
+            nn.BatchNorm1d(1000),
+            nn.Linear(1000, 256),
+            nn.ReLU(True),
+            nn.Dropout(dropout),
+        )
+        self.fc_final = nn.Sequential(nn.BatchNorm1d(256), nn.Linear(256, n_classes))

     def forward(self, x):
-        features = self.adapter(features)
+        features = self.adapter(x)
         features = self.fc_moba(features)
         out = self.fc_final(features)
         return out
@@ -39,31 +48,48 @@ class DAPPER_model(nn.Module):

 class ResNet_model(DAPPER_model):
     def __init__(self, version='152', pretrained=False, n_classes=2, dropout=0.5):
-        assert version in ['101', '152', '18', '34', '50'], "ResNet version not supported"
+        assert version in [
+            '101',
+            '152',
+            '18',
+            '34',
+            '50',
+        ], "ResNet version not supported"

-        features_extractors, resnet_n_features_out = self._get_conv_layers(version, pretrained)
+        features_extractors, resnet_n_features_out = self._get_conv_layers(
+            version, pretrained
+        )

-        super(ResNet_model, self).__init__(resnet_n_features_out, 1000, version, pretrained, n_classes, dropout)
+        super(ResNet_model, self).__init__(
+            resnet_n_features_out, 1000, version, pretrained, n_classes, dropout
+        )
         self.features_extractors = features_extractors

     def _get_conv_layers(self, version='152', pretrained=False):
         resnet_init = getattr(models, f'resnet{version}')
         resnet = resnet_init(pretrained)
-        features_extractors = nn.Sequential(resnet.conv1, resnet.bn1,resnet.relu, resnet.maxpool,
-                                            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
-                                            resnet.avgpool)
-        resnet_n_features_out = resnet.layer1[0].expansion*512
+        features_extractors = nn.Sequential(
+            resnet.conv1,
+            resnet.bn1,
+            resnet.relu,
+            resnet.maxpool,
+            resnet.layer1,
+            resnet.layer2,
+            resnet.layer3,
+            resnet.layer4,
+            resnet.avgpool,
+        )
+        resnet_n_features_out = resnet.layer1[0].expansion * 512
         return features_extractors, resnet_n_features_out

     def forward(self, x):
         features = self.features_extractors(x)
         features = features.view(features.size(0), -1)
-        out = super().forward(x)
+        out = super().forward(features)
         return out

     def test(self):
-        super().test() # run the forward pass
+        super().test()  # run the forward pass
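For context, the diff fixes two real bugs alongside the black reformatting: DAPPER_model.forward used self.adapter(features) before features existed, and ResNet_model.forward passed the raw input x instead of the extracted features to the parent forward. A minimal smoke test of the fixed forward pass follows (not part of this commit; the module path DP_classification.models.dapper_model is an assumption, since the diff does not show the file name):

import torch
from DP_classification.models.dapper_model import ResNet_model  # hypothetical module path

# Build a ResNet-18 backbone without pretrained weights; layer1[0].expansion is 1
# for BasicBlock, so the adapter receives 512 features from the flattened avgpool output.
model = ResNet_model(version='18', pretrained=False, n_classes=2, dropout=0.5)
model.eval()  # use BatchNorm running statistics instead of batch statistics

x = torch.randn(4, 3, 224, 224)  # dummy batch of four RGB images
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([4, 2]), one logit per class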